From a6de3e08935bcd090d4e8e1b2de989c6c4034187 Mon Sep 17 00:00:00 2001 From: DrRacket Date: Thu, 7 Dec 2017 07:28:15 +0000 Subject: [PATCH] 3804789168da replay starts --- .../java/hivemall/common/ConversionState.java | 21 +- core/src/main/java/hivemall/fm/Entry.java | 242 ++++++--- .../hivemall/fm/FFMPredictGenericUDAF.java | 262 ++++++++++ .../main/java/hivemall/fm/FFMPredictUDF.java | 187 ------- .../java/hivemall/fm/FFMPredictionModel.java | 349 ------------- .../hivemall/fm/FFMStringFeatureMapModel.java | 315 +++++++----- .../java/hivemall/fm/FMHyperParameters.java | 74 +-- .../hivemall/fm/FMIntFeatureMapModel.java | 6 +- .../hivemall/fm/FMPredictGenericUDAF.java | 15 + .../hivemall/fm/FMStringFeatureMapModel.java | 8 +- .../hivemall/fm/FactorizationMachineUDTF.java | 6 +- core/src/main/java/hivemall/fm/Feature.java | 76 ++- .../FieldAwareFactorizationMachineModel.java | 161 +++++- .../FieldAwareFactorizationMachineUDTF.java | 158 +++--- .../src/main/java/hivemall/fm/IntFeature.java | 6 +- .../ftvec/pairing/FeaturePairsUDTF.java | 155 +++++- .../ftvec/ranking/PositiveOnlyFeedback.java | 8 +- .../ftvec/trans/AddFieldIndicesUDF.java | 89 ++++ .../ftvec/trans/CategoricalFeaturesUDF.java | 121 ++++- .../hivemall/ftvec/trans/FFMFeaturesUDF.java | 47 +- .../ftvec/trans/QuantifiedFeaturesUDTF.java | 7 +- .../ftvec/trans/QuantitativeFeaturesUDF.java | 101 +++- .../ftvec/trans/VectorizeFeaturesUDF.java | 110 +++-- .../java/hivemall/mf/FactorizedModel.java | 18 +- .../model/AbstractPredictionModel.java | 6 +- .../java/hivemall/model/NewSparseModel.java | 2 +- .../main/java/hivemall/model/SparseModel.java | 2 +- .../tools/array/ArrayAvgGenericUDAF.java | 17 +- .../hivemall/utils/buffer/HeapBuffer.java | 37 +- .../maps/Int2FloatOpenHashTable.java | 11 +- .../maps/Int2IntOpenHashTable.java | 9 +- .../collections/maps/Int2LongOpenHashMap.java | 346 +++++++++++++ .../maps/Int2LongOpenHashTable.java | 114 ++--- .../collections/maps/IntOpenHashMap.java | 467 ------------------ .../collections/maps/IntOpenHashTable.java | 142 ++++-- .../maps/Long2DoubleOpenHashTable.java | 9 +- .../maps/Long2FloatOpenHashTable.java | 11 +- .../maps/Long2IntOpenHashTable.java | 9 +- .../utils/collections/maps/OpenHashMap.java | 128 +++-- .../utils/collections/maps/OpenHashTable.java | 12 +- .../java/hivemall/utils/hadoop/HiveUtils.java | 74 ++- .../hivemall/utils/hashing/HashUtils.java | 89 ++++ .../java/hivemall/utils/lang/NumberUtils.java | 68 +++ .../java/hivemall/utils/lang/Primitives.java | 24 - .../java/hivemall/utils/math/MathUtils.java | 33 +- .../hivemall/fm/FFMPredictionModelTest.java | 65 --- .../test/java/hivemall/fm/FeatureTest.java | 7 +- ...ieldAwareFactorizationMachineUDTFTest.java | 66 +-- .../smile/tools/TreePredictUDFv1Test.java | 1 + ...t.java => Int2FloatOpenHashTableTest.java} | 2 +- .../maps/Int2LongOpenHashMapTest.java | 66 +-- .../maps/Int2LongOpenHashTableTest.java | 130 +++++ .../collections/maps/IntOpenHashMapTest.java | 75 --- .../maps/IntOpenHashTableTest.java | 23 + ...st.java => Long2IntOpenHashTableTest.java} | 2 +- docs/gitbook/getting_started/input-format.md | 31 +- pom.xml | 18 + resources/ddl/define-all-as-permanent.hive | 5 +- resources/ddl/define-all.hive | 5 +- resources/ddl/define-all.spark | 5 +- resources/ddl/define-udfs.td.hql | 3 + spark/spark-2.0/pom.xml | 10 +- spark/spark-2.1/pom.xml | 10 +- spark/spark-common/pom.xml | 8 +- 64 files changed, 2750 insertions(+), 1934 deletions(-) create mode 100644 core/src/main/java/hivemall/fm/FFMPredictGenericUDAF.java delete mode 100644 
core/src/main/java/hivemall/fm/FFMPredictUDF.java delete mode 100644 core/src/main/java/hivemall/fm/FFMPredictionModel.java create mode 100644 core/src/main/java/hivemall/ftvec/trans/AddFieldIndicesUDF.java create mode 100644 core/src/main/java/hivemall/utils/collections/maps/Int2LongOpenHashMap.java delete mode 100644 core/src/main/java/hivemall/utils/collections/maps/IntOpenHashMap.java create mode 100644 core/src/main/java/hivemall/utils/hashing/HashUtils.java delete mode 100644 core/src/test/java/hivemall/fm/FFMPredictionModelTest.java rename core/src/test/java/hivemall/utils/collections/maps/{Int2FloatOpenHashMapTest.java => Int2FloatOpenHashTableTest.java} (98%) create mode 100644 core/src/test/java/hivemall/utils/collections/maps/Int2LongOpenHashTableTest.java delete mode 100644 core/src/test/java/hivemall/utils/collections/maps/IntOpenHashMapTest.java rename core/src/test/java/hivemall/utils/collections/maps/{Long2IntOpenHashMapTest.java => Long2IntOpenHashTableTest.java} (98%) diff --git a/core/src/main/java/hivemall/common/ConversionState.java b/core/src/main/java/hivemall/common/ConversionState.java index 7b5923ff0..435bf7549 100644 --- a/core/src/main/java/hivemall/common/ConversionState.java +++ b/core/src/main/java/hivemall/common/ConversionState.java @@ -99,18 +99,25 @@ public boolean isConverged(final long observedTrainingExamples) { if (changeRate < convergenceRate) { if (readyToFinishIterations) { // NOTE: never be true at the first iteration where prevLosses == Double.POSITIVE_INFINITY - logger.info("Training converged at " + curIter + "-th iteration. [curLosses=" - + currLosses + ", prevLosses=" + prevLosses + ", changeRate=" + changeRate - + ']'); + if (logger.isInfoEnabled()) { + logger.info("Training converged at " + curIter + "-th iteration. 
[curLosses=" + + currLosses + ", prevLosses=" + prevLosses + ", changeRate=" + + changeRate + ']'); + } return true; } else { + if (logger.isInfoEnabled()) { + logger.info("Iteration #" + curIter + " [curLosses=" + currLosses + + ", prevLosses=" + prevLosses + ", changeRate=" + changeRate + + ", #trainingExamples=" + observedTrainingExamples + ']'); + } this.readyToFinishIterations = true; } } else { - if (logger.isDebugEnabled()) { - logger.debug("Iteration #" + curIter + " [curLosses=" + currLosses - + ", prevLosses=" + prevLosses + ", changeRate=" + changeRate - + ", #trainingExamples=" + observedTrainingExamples + ']'); + if (logger.isInfoEnabled()) { + logger.info("Iteration #" + curIter + " [curLosses=" + currLosses + ", prevLosses=" + + prevLosses + ", changeRate=" + changeRate + ", #trainingExamples=" + + observedTrainingExamples + ']'); } this.readyToFinishIterations = false; } diff --git a/core/src/main/java/hivemall/fm/Entry.java b/core/src/main/java/hivemall/fm/Entry.java index 1882f8584..974ab5ba1 100644 --- a/core/src/main/java/hivemall/fm/Entry.java +++ b/core/src/main/java/hivemall/fm/Entry.java @@ -20,17 +20,27 @@ import hivemall.utils.buffer.HeapBuffer; import hivemall.utils.lang.NumberUtils; +import hivemall.utils.lang.Preconditions; import hivemall.utils.lang.SizeOf; +import hivemall.utils.math.MathUtils; +import java.util.Arrays; + +import javax.annotation.Nonnegative; import javax.annotation.Nonnull; class Entry { @Nonnull protected final HeapBuffer _buf; + @Nonnegative protected final int _size; + @Nonnegative protected final int _factors; + // temporary variables used only in training phase + protected int _key; + @Nonnegative protected long _offset; Entry(@Nonnull HeapBuffer buf, int factors) { @@ -39,128 +49,210 @@ class Entry { this._factors = factors; } - Entry(@Nonnull HeapBuffer buf, int factors, long offset) { - this(buf, factors, Entry.sizeOf(factors), offset); + Entry(@Nonnull HeapBuffer buf, int key, @Nonnegative long offset) { + this(buf, 1, key, offset); + } + + Entry(@Nonnull HeapBuffer buf, int factors, int key, @Nonnegative long offset) { + this(buf, factors, Entry.sizeOf(factors), key, offset); } - private Entry(@Nonnull HeapBuffer buf, int factors, int size, long offset) { + private Entry(@Nonnull HeapBuffer buf, int factors, int size, int key, @Nonnegative long offset) { this._buf = buf; this._size = size; this._factors = factors; - setOffset(offset); + this._key = key; + this._offset = offset; } - int getSize() { + final int getSize() { return _size; } - long getOffset() { + final int getKey() { + return _key; + } + + final long getOffset() { return _offset; } - void setOffset(long offset) { + final void setOffset(final long offset) { this._offset = offset; } - float getW() { + final float getW() { return _buf.getFloat(_offset); } - void setW(final float value) { + final void setW(final float value) { _buf.putFloat(_offset, value); } - void getV(@Nonnull final float[] Vf) { - final long offset = _offset + SizeOf.FLOAT; + final void getV(@Nonnull final float[] Vf) { + final long offset = _offset; final int len = Vf.length; - for (int i = 0; i < len; i++) { - Vf[i] = _buf.getFloat(offset + SizeOf.FLOAT * i); + for (int f = 0; f < len; f++) { + long index = offset + SizeOf.FLOAT * f; + Vf[f] = _buf.getFloat(index); } } - void setV(@Nonnull final float[] Vf) { - final long offset = _offset + SizeOf.FLOAT; + final void setV(@Nonnull final float[] Vf) { + final long offset = _offset; final int len = Vf.length; - for (int i = 0; i < len; i++) { - 
_buf.putFloat(offset + SizeOf.FLOAT * i, Vf[i]); + for (int f = 0; f < len; f++) { + long index = offset + SizeOf.FLOAT * f; + _buf.putFloat(index, Vf[f]); } } - float getV(final int f) { - return _buf.getFloat(_offset + SizeOf.FLOAT + SizeOf.FLOAT * f); + final float getV(final int f) { + long index = _offset + SizeOf.FLOAT * f; + return _buf.getFloat(index); } - void setV(final int f, final float value) { - long index = _offset + SizeOf.FLOAT + SizeOf.FLOAT * f; + final void setV(final int f, final float value) { + long index = _offset + SizeOf.FLOAT * f; _buf.putFloat(index, value); } - double getSumOfSquaredGradientsV() { + double getSumOfSquaredGradients(@Nonnegative int f) { throw new UnsupportedOperationException(); } - void addGradientV(float grad) { + void addGradient(@Nonnegative int f, float grad) { throw new UnsupportedOperationException(); } - float updateZ(float gradW, float alpha) { + final float updateZ(final float gradW, final float alpha) { + float w = getW(); + return updateZ(0, w, gradW, alpha); + } + + float updateZ(@Nonnegative int f, float W, float gradW, float alpha) { throw new UnsupportedOperationException(); } - double updateN(float gradW) { + final double updateN(final float gradW) { + return updateN(0, gradW); + } + + double updateN(@Nonnegative int f, float gradW) { throw new UnsupportedOperationException(); } - static int sizeOf(int factors) { - return SizeOf.FLOAT + SizeOf.FLOAT * factors; + boolean removable() { + if (!isEntryW(_key)) {// entry for V + final long offset = _offset; + for (int f = 0; f < _factors; f++) { + final float Vf = _buf.getFloat(offset + SizeOf.FLOAT * f); + if (!MathUtils.closeToZero(Vf, 1E-9f)) { + return false; + } + } + } + return true; + } + + void clear() {} + + static int sizeOf(@Nonnegative final int factors) { + Preconditions.checkArgument(factors >= 1, "Factors must be greater than 0: " + factors); + return SizeOf.FLOAT * factors; + } + + static boolean isEntryW(final int i) { + return i < 0; + } + + @Override + public String toString() { + if (Entry.isEntryW(_key)) { + return "W=" + getW(); + } else { + float[] Vf = new float[_factors]; + getV(Vf); + return "V=" + Arrays.toString(Vf); + } } - static class AdaGradEntry extends Entry { + static final class AdaGradEntry extends Entry { final long _gg_offset; - AdaGradEntry(@Nonnull HeapBuffer buf, int factors, long offset) { - super(buf, factors, AdaGradEntry.sizeOf(factors), offset); - this._gg_offset = _offset + SizeOf.FLOAT + SizeOf.FLOAT * _factors; + AdaGradEntry(@Nonnull HeapBuffer buf, int key, @Nonnegative long offset) { + this(buf, 1, key, offset); } - private AdaGradEntry(@Nonnull HeapBuffer buf, int factors, int size, long offset) { - super(buf, factors, size, offset); - this._gg_offset = _offset + SizeOf.FLOAT + SizeOf.FLOAT * _factors; + AdaGradEntry(@Nonnull HeapBuffer buf, @Nonnegative int factors, int key, + @Nonnegative long offset) { + super(buf, factors, AdaGradEntry.sizeOf(factors), key, offset); + this._gg_offset = _offset + Entry.sizeOf(factors); } @Override - double getSumOfSquaredGradientsV() { - return _buf.getDouble(_gg_offset); + double getSumOfSquaredGradients(@Nonnegative final int f) { + Preconditions.checkArgument(f >= 0); + + long offset = _gg_offset + SizeOf.DOUBLE * f; + return _buf.getDouble(offset); } @Override - void addGradientV(float grad) { - double v = _buf.getDouble(_gg_offset); + void addGradient(@Nonnegative final int f, final float grad) { + Preconditions.checkArgument(f >= 0); + + long offset = _gg_offset + SizeOf.DOUBLE * f; + 
double v = _buf.getDouble(offset); v += grad * grad; - _buf.putDouble(_gg_offset, v); + _buf.putDouble(offset, v); } - static int sizeOf(int factors) { - return Entry.sizeOf(factors) + SizeOf.DOUBLE; + @Override + void clear() { + for (int f = 0; f < _factors; f++) { + long offset = _gg_offset + SizeOf.DOUBLE * f; + _buf.putDouble(offset, 0.d); + } + } + + static int sizeOf(@Nonnegative final int factors) { + return Entry.sizeOf(factors) + SizeOf.DOUBLE * factors; + } + + @Override + public String toString() { + final double[] gg = new double[_factors]; + for (int f = 0; f < _factors; f++) { + gg[f] = getSumOfSquaredGradients(f); + } + return super.toString() + ", gg=" + Arrays.toString(gg); } } - static final class FTRLEntry extends AdaGradEntry { + static final class FTRLEntry extends Entry { final long _z_offset; - FTRLEntry(@Nonnull HeapBuffer buf, int factors, long offset) { - super(buf, factors, FTRLEntry.sizeOf(factors), offset); - this._z_offset = _gg_offset + SizeOf.DOUBLE; + FTRLEntry(@Nonnull HeapBuffer buf, int key, long offset) { + this(buf, 1, key, offset); + } + + FTRLEntry(@Nonnull HeapBuffer buf, @Nonnegative int factors, int key, long offset) { + super(buf, factors, FTRLEntry.sizeOf(factors), key, offset); + this._z_offset = _offset + Entry.sizeOf(factors); } @Override - float updateZ(float gradW, float alpha) { - final float W = getW(); - final float z = getZ(); - final double n = getN(); + float updateZ(final int f, final float W, final float gradW, final float alpha) { + Preconditions.checkArgument(f >= 0); + + final long zOffset = offsetZ(f); + + final float z = _buf.getFloat(zOffset); + final double n = _buf.getFloat(offsetN(f)); // implicit cast to float double gg = gradW * gradW; float sigma = (float) ((Math.sqrt(n + gg) - Math.sqrt(n)) / alpha); @@ -171,44 +263,56 @@ float updateZ(float gradW, float alpha) { + gradW + ", sigma=" + sigma + ", W=" + W + ", n=" + n + ", gg=" + gg + ", alpha=" + alpha); } - setZ(newZ); + _buf.putFloat(zOffset, newZ); return newZ; } - private float getZ() { - return _buf.getFloat(_z_offset); - } - - private void setZ(final float value) { - _buf.putFloat(_z_offset, value); - } - @Override - double updateN(final float gradW) { - final double n = getN(); + double updateN(final int f, final float gradW) { + Preconditions.checkArgument(f >= 0); + + final long nOffset = offsetN(f); + final double n = _buf.getFloat(nOffset); final double newN = n + gradW * gradW; if (!NumberUtils.isFinite(newN)) { throw new IllegalStateException("Got newN " + newN + " where n=" + n + ", gradW=" + gradW); } - setN(newN); + _buf.putFloat(nOffset, NumberUtils.castToFloat(newN)); // cast may throw ArithmeticException return newN; } - private double getN() { - long index = _z_offset + SizeOf.FLOAT; - return _buf.getDouble(index); + private long offsetZ(@Nonnegative final int f) { + return _z_offset + SizeOf.FLOAT * f; } - private void setN(final double value) { - long index = _z_offset + SizeOf.FLOAT; - _buf.putDouble(index, value); + private long offsetN(@Nonnegative final int f) { + return _z_offset + SizeOf.FLOAT * (_factors + f); } - static int sizeOf(int factors) { - return AdaGradEntry.sizeOf(factors) + SizeOf.FLOAT + SizeOf.DOUBLE; + @Override + void clear() { + for (int f = 0; f < _factors; f++) { + _buf.putFloat(offsetZ(f), 0.f); + _buf.putFloat(offsetN(f), 0.f); + } } + static int sizeOf(@Nonnegative final int factors) { + return Entry.sizeOf(factors) + (SizeOf.FLOAT + SizeOf.FLOAT) * factors; + } + + @Override + public String toString() { + final 
float[] Z = new float[_factors]; + final float[] N = new float[_factors]; + for (int f = 0; f < _factors; f++) { + Z[f] = _buf.getFloat(offsetZ(f)); + N[f] = _buf.getFloat(offsetN(f)); + } + return super.toString() + ", Z=" + Arrays.toString(Z) + ", N=" + Arrays.toString(N); + } } + } diff --git a/core/src/main/java/hivemall/fm/FFMPredictGenericUDAF.java b/core/src/main/java/hivemall/fm/FFMPredictGenericUDAF.java new file mode 100644 index 000000000..7cbd6889e --- /dev/null +++ b/core/src/main/java/hivemall/fm/FFMPredictGenericUDAF.java @@ -0,0 +1,262 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.fm; + +import hivemall.utils.hadoop.HiveUtils; +import hivemall.utils.lang.SizeOf; + +import javax.annotation.Nonnull; + +import org.apache.hadoop.hive.ql.exec.Description; +import org.apache.hadoop.hive.ql.exec.UDFArgumentException; +import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException; +import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.ql.parse.SemanticException; +import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AbstractAggregationBuffer; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationType; +import org.apache.hadoop.hive.serde2.io.DoubleWritable; +import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category; +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.DoubleObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils; +import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo; + +@Description(name = "ffm_predict", + value = "_FUNC_(float Wi, array Vifj, array Vjfi, float Xi, float Xj)" + + " - Returns a prediction value in Double") +public final class FFMPredictGenericUDAF extends AbstractGenericUDAFResolver { + + private FFMPredictGenericUDAF() {} + + @Override + public Evaluator getEvaluator(@Nonnull TypeInfo[] typeInfo) throws SemanticException { + if (typeInfo.length != 5) { + throw new UDFArgumentLengthException( + "Expected argument length is 5 but given argument length was " + typeInfo.length); + } + if (!HiveUtils.isNumberTypeInfo(typeInfo[0])) { + 
throw new UDFArgumentTypeException(0, + "Number type is expected for the first argument Wi: " + typeInfo[0].getTypeName()); + } + if (typeInfo[1].getCategory() != Category.LIST) { + throw new UDFArgumentTypeException(1, + "List type is expected for the second argument Vifj: " + typeInfo[1].getTypeName()); + } + if (typeInfo[2].getCategory() != Category.LIST) { + throw new UDFArgumentTypeException(2, + "List type is expected for the third argument Vjfi: " + typeInfo[2].getTypeName()); + } + ListTypeInfo typeInfo1 = (ListTypeInfo) typeInfo[1]; + if (!HiveUtils.isFloatingPointTypeInfo(typeInfo1.getListElementTypeInfo())) { + throw new UDFArgumentTypeException(1, + "Double or Float type is expected for the element type of list Vifj: " + + typeInfo1.getTypeName()); + } + ListTypeInfo typeInfo2 = (ListTypeInfo) typeInfo[2]; + if (!HiveUtils.isFloatingPointTypeInfo(typeInfo2.getListElementTypeInfo())) { + throw new UDFArgumentTypeException(2, + "Double or Float type is expected for the element type of list Vjfi: " + + typeInfo2.getTypeName()); + } + if (!HiveUtils.isNumberTypeInfo(typeInfo[3])) { + throw new UDFArgumentTypeException(3, + "Number type is expected for the fourth argument Xi: " + typeInfo[3].getTypeName()); + } + if (!HiveUtils.isNumberTypeInfo(typeInfo[4])) { + throw new UDFArgumentTypeException(4, + "Number type is expected for the fifth argument Xj: " + typeInfo[4].getTypeName()); + } + return new Evaluator(); + } + + public static final class Evaluator extends GenericUDAFEvaluator { + + // input OI + private PrimitiveObjectInspector wiOI; + private ListObjectInspector vijOI, vjiOI; + private PrimitiveObjectInspector vijElemOI, vjiElemOI; + private PrimitiveObjectInspector xiOI, xjOI; + + // merge input OI + private DoubleObjectInspector mergeInputOI; + + public Evaluator() {} + + @Override + public ObjectInspector init(Mode mode, ObjectInspector[] parameters) throws HiveException { + assert (parameters.length == 5); + super.init(mode, parameters); + + // initialize input + if (mode == Mode.PARTIAL1 || mode == Mode.COMPLETE) {// from original data + this.wiOI = HiveUtils.asDoubleCompatibleOI(parameters[0]); + this.vijOI = HiveUtils.asListOI(parameters[1]); + this.vijElemOI = HiveUtils.asFloatingPointOI(vijOI.getListElementObjectInspector()); + this.vjiOI = HiveUtils.asListOI(parameters[2]); + this.vjiElemOI = HiveUtils.asFloatingPointOI(vjiOI.getListElementObjectInspector()); + this.xiOI = HiveUtils.asDoubleCompatibleOI(parameters[3]); + this.xjOI = HiveUtils.asDoubleCompatibleOI(parameters[4]); + } else {// from partial aggregation + this.mergeInputOI = HiveUtils.asDoubleOI(parameters[0]); + } + + return PrimitiveObjectInspectorFactory.writableDoubleObjectInspector; + } + + @Override + public FFMPredictAggregationBuffer getNewAggregationBuffer() throws HiveException { + FFMPredictAggregationBuffer myAggr = new FFMPredictAggregationBuffer(); + reset(myAggr); + return myAggr; + } + + @Override + public void reset(@SuppressWarnings("deprecation") AggregationBuffer agg) + throws HiveException { + FFMPredictAggregationBuffer myAggr = (FFMPredictAggregationBuffer) agg; + myAggr.reset(); + } + + @Override + public void iterate(@SuppressWarnings("deprecation") AggregationBuffer agg, + Object[] parameters) throws HiveException { + final FFMPredictAggregationBuffer myAggr = (FFMPredictAggregationBuffer) agg; + + if (parameters[0] == null) {// Wi is null + if (parameters[3] == null || parameters[4] == null) { + // an interaction term requires both Xi and Xj to be non-null + return; + } + if (parameters[1] == null 
|| parameters[2] == null) { + // vi, vj can be null where feature index does not exist in the prediction model + return; + } + + // (i, j, xi, xj) => (wi, vi, vj, xi, xj) + float[] vij = HiveUtils.asFloatArray(parameters[1], vijOI, vijElemOI, false); + float[] vji = HiveUtils.asFloatArray(parameters[2], vjiOI, vjiElemOI, false); + double xi = PrimitiveObjectInspectorUtils.getDouble(parameters[3], xiOI); + double xj = PrimitiveObjectInspectorUtils.getDouble(parameters[4], xjOI); + + myAggr.addViVjXiXj(vij, vji, xi, xj); + } else { + final double wi = PrimitiveObjectInspectorUtils.getDouble(parameters[0], wiOI); + + if (parameters[3] == null && parameters[4] == null) {// Xi and Xj are null => global bias `w0` + // (i=0, j=null, xi=null, xj=null) => (wi, vi=?, vj=null, xi=null, xj=null) + myAggr.addW0(wi); + } else if (parameters[4] == null) {// Only Xi is nonnull => linear combination `wi` * `xi` + // (i, j=null, xi, xj=null) => (wi, vi, vj=null, xi, xj=null) + double xi = PrimitiveObjectInspectorUtils.getDouble(parameters[3], xiOI); + myAggr.addWiXi(wi, xi); + } + } + } + + @Override + public DoubleWritable terminatePartial( + @SuppressWarnings("deprecation") AggregationBuffer agg) throws HiveException { + FFMPredictAggregationBuffer myAggr = (FFMPredictAggregationBuffer) agg; + double sum = myAggr.get(); + return new DoubleWritable(sum); + } + + @Override + public void merge(@SuppressWarnings("deprecation") AggregationBuffer agg, Object partial) + throws HiveException { + if (partial == null) { + return; + } + + FFMPredictAggregationBuffer myAggr = (FFMPredictAggregationBuffer) agg; + double sum = mergeInputOI.get(partial); + myAggr.merge(sum); + } + + @Override + public DoubleWritable terminate(@SuppressWarnings("deprecation") AggregationBuffer agg) + throws HiveException { + FFMPredictAggregationBuffer myAggr = (FFMPredictAggregationBuffer) agg; + double result = myAggr.get(); + return new DoubleWritable(result); + } + + } + + @AggregationType(estimable = true) + public static final class FFMPredictAggregationBuffer extends AbstractAggregationBuffer { + + private double sum; + + FFMPredictAggregationBuffer() { + super(); + } + + void reset() { + this.sum = 0.d; + } + + void merge(double o_sum) { + this.sum += o_sum; + } + + double get() { + return sum; + } + + void addW0(final double W0) { + this.sum += W0; + } + + void addWiXi(final double Wi, final double Xi) { + this.sum += (Wi * Xi); + } + + void addViVjXiXj(@Nonnull final float[] Vij, @Nonnull final float[] Vji, final double Xi, + final double Xj) throws UDFArgumentException { + if (Vij.length != Vji.length) { + throw new UDFArgumentException("Mismatch in the number of factors"); + } + + final int factors = Vij.length; + + // compute inner product + double prod = 0.d; + for (int f = 0; f < factors; f++) { + prod += (Vij[f] * Vji[f]); + } + + this.sum += (prod * Xi * Xj); + } + + @Override + public int estimate() { + return SizeOf.DOUBLE; + } + + } + +} diff --git a/core/src/main/java/hivemall/fm/FFMPredictUDF.java b/core/src/main/java/hivemall/fm/FFMPredictUDF.java deleted file mode 100644 index 48745d948..000000000 --- a/core/src/main/java/hivemall/fm/FFMPredictUDF.java +++ /dev/null @@ -1,187 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. 
The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package hivemall.fm; - -import hivemall.annotations.Experimental; -import hivemall.utils.hadoop.HiveUtils; -import hivemall.utils.lang.NumberUtils; - -import java.io.IOException; -import java.util.Arrays; - -import javax.annotation.Nonnull; -import javax.annotation.Nullable; - -import org.apache.hadoop.hive.ql.exec.Description; -import org.apache.hadoop.hive.ql.exec.UDFArgumentException; -import org.apache.hadoop.hive.ql.metadata.HiveException; -import org.apache.hadoop.hive.ql.udf.UDFType; -import org.apache.hadoop.hive.ql.udf.generic.GenericUDF; -import org.apache.hadoop.hive.serde2.io.DoubleWritable; -import org.apache.hadoop.hive.serde2.lazybinary.LazyBinaryArray; -import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector; -import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; -import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; -import org.apache.hadoop.hive.serde2.objectinspector.primitive.StringObjectInspector; -import org.apache.hadoop.io.Text; - -/** - * @since v0.5-rc.1 - */ -@Description(name = "ffm_predict", - value = "_FUNC_(string modelId, string model, array features)" - + " returns a prediction result in double from a Field-aware Factorization Machine") -@UDFType(deterministic = true, stateful = false) -@Experimental -public final class FFMPredictUDF extends GenericUDF { - - private StringObjectInspector _modelIdOI; - private StringObjectInspector _modelOI; - private ListObjectInspector _featureListOI; - - private DoubleWritable _result; - @Nullable - private String _cachedModeId; - @Nullable - private FFMPredictionModel _cachedModel; - @Nullable - private Feature[] _probes; - - public FFMPredictUDF() {} - - @Override - public ObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException { - if (argOIs.length != 3) { - throw new UDFArgumentException("_FUNC_ takes 3 arguments"); - } - this._modelIdOI = HiveUtils.asStringOI(argOIs[0]); - this._modelOI = HiveUtils.asStringOI(argOIs[1]); - this._featureListOI = HiveUtils.asListOI(argOIs[2]); - - this._result = new DoubleWritable(); - return PrimitiveObjectInspectorFactory.writableDoubleObjectInspector; - } - - @Override - public Object evaluate(DeferredObject[] args) throws HiveException { - String modelId = _modelIdOI.getPrimitiveJavaObject(args[0].get()); - if (modelId == null) { - throw new HiveException("modelId is not set"); - } - - final FFMPredictionModel model; - if (modelId.equals(_cachedModeId)) { - model = this._cachedModel; - } else { - Text serModel = _modelOI.getPrimitiveWritableObject(args[1].get()); - if (serModel == null) { - throw new HiveException("Model is null for model ID: " + modelId); - } - byte[] b = serModel.getBytes(); - final int length = serModel.getLength(); - try { - model = FFMPredictionModel.deserialize(b, length); - b = null; - } catch (ClassNotFoundException e) { - throw new HiveException(e); - } catch (IOException 
e) { - throw new HiveException(e); - } - this._cachedModeId = modelId; - this._cachedModel = model; - } - - int numFeatures = model.getNumFeatures(); - int numFields = model.getNumFields(); - - Object arg2 = args[2].get(); - // [workaround] - // java.lang.ClassCastException: org.apache.hadoop.hive.serde2.lazybinary.LazyBinaryArray - // cannot be cast to [Ljava.lang.Object; - if (arg2 instanceof LazyBinaryArray) { - arg2 = ((LazyBinaryArray) arg2).getList(); - } - Feature[] x = Feature.parseFFMFeatures(arg2, _featureListOI, _probes, numFeatures, - numFields); - if (x == null || x.length == 0) { - return null; // return NULL if there are no features - } - this._probes = x; - - double predicted = predict(x, model); - _result.set(predicted); - return _result; - } - - private static double predict(@Nonnull final Feature[] x, - @Nonnull final FFMPredictionModel model) throws HiveException { - // w0 - double ret = model.getW0(); - // W - for (Feature e : x) { - double xi = e.getValue(); - float wi = model.getW(e); - double wx = wi * xi; - ret += wx; - } - // V - final int factors = model.getNumFactors(); - final float[] vij = new float[factors]; - final float[] vji = new float[factors]; - for (int i = 0; i < x.length; ++i) { - final Feature ei = x[i]; - final double xi = ei.getValue(); - final int iField = ei.getField(); - for (int j = i + 1; j < x.length; ++j) { - final Feature ej = x[j]; - final double xj = ej.getValue(); - final int jField = ej.getField(); - if (!model.getV(ei, jField, vij)) { - continue; - } - if (!model.getV(ej, iField, vji)) { - continue; - } - for (int f = 0; f < factors; f++) { - float vijf = vij[f]; - float vjif = vji[f]; - ret += vijf * vjif * xi * xj; - } - } - } - if (!NumberUtils.isFinite(ret)) { - throw new HiveException("Detected " + ret + " in ffm_predict"); - } - return ret; - } - - @Override - public void close() throws IOException { - super.close(); - // clean up to help GC - this._cachedModel = null; - this._probes = null; - } - - @Override - public String getDisplayString(String[] args) { - return "ffm_predict(" + Arrays.toString(args) + ")"; - } - -} diff --git a/core/src/main/java/hivemall/fm/FFMPredictionModel.java b/core/src/main/java/hivemall/fm/FFMPredictionModel.java deleted file mode 100644 index befbec9d0..000000000 --- a/core/src/main/java/hivemall/fm/FFMPredictionModel.java +++ /dev/null @@ -1,349 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ -package hivemall.fm; - -import hivemall.utils.buffer.HeapBuffer; -import hivemall.utils.codec.VariableByteCodec; -import hivemall.utils.codec.ZigZagLEB128Codec; -import hivemall.utils.collections.maps.Int2LongOpenHashTable; -import hivemall.utils.collections.maps.IntOpenHashTable; -import hivemall.utils.io.CompressionStreamFactory.CompressionAlgorithm; -import hivemall.utils.io.IOUtils; -import hivemall.utils.lang.ArrayUtils; -import hivemall.utils.lang.HalfFloat; -import hivemall.utils.lang.ObjectUtils; - -import java.io.DataInput; -import java.io.DataOutput; -import java.io.Externalizable; -import java.io.IOException; -import java.io.ObjectInput; -import java.io.ObjectOutput; -import java.util.Arrays; - -import javax.annotation.Nonnull; -import javax.annotation.Nullable; - -import org.apache.commons.logging.Log; -import org.apache.commons.logging.LogFactory; - -public final class FFMPredictionModel implements Externalizable { - private static final Log LOG = LogFactory.getLog(FFMPredictionModel.class); - - private static final byte HALF_FLOAT_ENTRY = 1; - private static final byte W_ONLY_HALF_FLOAT_ENTRY = 2; - private static final byte FLOAT_ENTRY = 3; - private static final byte W_ONLY_FLOAT_ENTRY = 4; - - /** - * maps feature to feature weight pointer - */ - private Int2LongOpenHashTable _map; - private HeapBuffer _buf; - - private double _w0; - private int _factors; - private int _numFeatures; - private int _numFields; - - public FFMPredictionModel() {}// for Externalizable - - public FFMPredictionModel(@Nonnull Int2LongOpenHashTable map, @Nonnull HeapBuffer buf, - double w0, int factor, int numFeatures, int numFields) { - this._map = map; - this._buf = buf; - this._w0 = w0; - this._factors = factor; - this._numFeatures = numFeatures; - this._numFields = numFields; - } - - public int getNumFactors() { - return _factors; - } - - public double getW0() { - return _w0; - } - - public int getNumFeatures() { - return _numFeatures; - } - - public int getNumFields() { - return _numFields; - } - - public int getActualNumFeatures() { - return _map.size(); - } - - public long approxBytesConsumed() { - int size = _map.size(); - - // [map] size * (|state| + |key| + |entry|) - long bytes = size * (1L + 4L + 4L + (4L * _factors)); - int rest = _map.capacity() - size; - if (rest > 0) { - bytes += rest * 1L; - } - // w0, factors, numFeatures, numFields, used, size - bytes += (8 + 4 + 4 + 4 + 4 + 4); - return bytes; - } - - @Nullable - private Entry getEntry(final int key) { - final long ptr = _map.get(key); - if (ptr == -1L) { - return null; - } - return new Entry(_buf, _factors, ptr); - } - - public float getW(@Nonnull final Feature x) { - int j = x.getFeatureIndex(); - - Entry entry = getEntry(j); - if (entry == null) { - return 0.f; - } - return entry.getW(); - } - - /** - * @return true if V exists - */ - public boolean getV(@Nonnull final Feature x, @Nonnull final int yField, @Nonnull float[] dst) { - int j = Feature.toIntFeature(x, yField, _numFields); - - Entry entry = getEntry(j); - if (entry == null) { - return false; - } - - entry.getV(dst); - if (ArrayUtils.equals(dst, 0.f)) { - return false; // treat as null - } - return true; - } - - @Override - public void writeExternal(@Nonnull ObjectOutput out) throws IOException { - out.writeDouble(_w0); - final int factors = _factors; - out.writeInt(factors); - out.writeInt(_numFeatures); - out.writeInt(_numFields); - - int used = _map.size(); - out.writeInt(used); - - final int[] keys = _map.getKeys(); - final int size = keys.length; - 
out.writeInt(size); - - final byte[] states = _map.getStates(); - writeStates(states, out); - - final long[] values = _map.getValues(); - - final HeapBuffer buf = _buf; - final Entry e = new Entry(buf, factors); - final float[] Vf = new float[factors]; - for (int i = 0; i < size; i++) { - if (states[i] != IntOpenHashTable.FULL) { - continue; - } - ZigZagLEB128Codec.writeSignedInt(keys[i], out); - e.setOffset(values[i]); - writeEntry(e, factors, Vf, out); - } - - // help GC - this._map = null; - this._buf = null; - } - - private static void writeEntry(@Nonnull final Entry e, final int factors, - @Nonnull final float[] Vf, @Nonnull final DataOutput out) throws IOException { - final float W = e.getW(); - e.getV(Vf); - - if (ArrayUtils.almostEquals(Vf, 0.f)) { - if (HalfFloat.isRepresentable(W)) { - out.writeByte(W_ONLY_HALF_FLOAT_ENTRY); - out.writeShort(HalfFloat.floatToHalfFloat(W)); - } else { - out.writeByte(W_ONLY_FLOAT_ENTRY); - out.writeFloat(W); - } - } else if (isRepresentableAsHalfFloat(W, Vf)) { - out.writeByte(HALF_FLOAT_ENTRY); - out.writeShort(HalfFloat.floatToHalfFloat(W)); - for (int i = 0; i < factors; i++) { - out.writeShort(HalfFloat.floatToHalfFloat(Vf[i])); - } - } else { - out.writeByte(FLOAT_ENTRY); - out.writeFloat(W); - IOUtils.writeFloats(Vf, factors, out); - } - } - - private static boolean isRepresentableAsHalfFloat(final float W, @Nonnull final float[] Vf) { - if (!HalfFloat.isRepresentable(W)) { - return false; - } - for (float V : Vf) { - if (!HalfFloat.isRepresentable(V)) { - return false; - } - } - return true; - } - - @Nonnull - static void writeStates(@Nonnull final byte[] status, @Nonnull final DataOutput out) - throws IOException { - // write empty states's indexes differentially - final int size = status.length; - int cardinarity = 0; - for (int i = 0; i < size; i++) { - if (status[i] != IntOpenHashTable.FULL) { - cardinarity++; - } - } - out.writeInt(cardinarity); - if (cardinarity == 0) { - return; - } - int prev = 0; - for (int i = 0; i < size; i++) { - if (status[i] != IntOpenHashTable.FULL) { - int diff = i - prev; - assert (diff >= 0); - VariableByteCodec.encodeUnsignedInt(diff, out); - prev = i; - } - } - } - - @Override - public void readExternal(@Nonnull final ObjectInput in) throws IOException, - ClassNotFoundException { - this._w0 = in.readDouble(); - final int factors = in.readInt(); - this._factors = factors; - this._numFeatures = in.readInt(); - this._numFields = in.readInt(); - - final int used = in.readInt(); - final int size = in.readInt(); - final int[] keys = new int[size]; - final long[] values = new long[size]; - final byte[] states = new byte[size]; - readStates(in, states); - - final int entrySize = Entry.sizeOf(factors); - int numChunks = (entrySize * used) / HeapBuffer.DEFAULT_CHUNK_BYTES + 1; - final HeapBuffer buf = new HeapBuffer(HeapBuffer.DEFAULT_CHUNK_SIZE, numChunks); - final Entry e = new Entry(buf, factors); - final float[] Vf = new float[factors]; - for (int i = 0; i < size; i++) { - if (states[i] != IntOpenHashTable.FULL) { - continue; - } - keys[i] = ZigZagLEB128Codec.readSignedInt(in); - long ptr = buf.allocate(entrySize); - e.setOffset(ptr); - readEntry(in, factors, Vf, e); - values[i] = ptr; - } - - this._map = new Int2LongOpenHashTable(keys, values, states, used); - this._buf = buf; - } - - @Nonnull - private static void readEntry(@Nonnull final DataInput in, final int factors, - @Nonnull final float[] Vf, @Nonnull Entry dst) throws IOException { - final byte type = in.readByte(); - switch (type) { - case 
HALF_FLOAT_ENTRY: { - float W = HalfFloat.halfFloatToFloat(in.readShort()); - dst.setW(W); - for (int i = 0; i < factors; i++) { - Vf[i] = HalfFloat.halfFloatToFloat(in.readShort()); - } - dst.setV(Vf); - break; - } - case W_ONLY_HALF_FLOAT_ENTRY: { - float W = HalfFloat.halfFloatToFloat(in.readShort()); - dst.setW(W); - break; - } - case FLOAT_ENTRY: { - float W = in.readFloat(); - dst.setW(W); - IOUtils.readFloats(in, Vf); - dst.setV(Vf); - break; - } - case W_ONLY_FLOAT_ENTRY: { - float W = in.readFloat(); - dst.setW(W); - break; - } - default: - throw new IOException("Unexpected Entry type: " + type); - } - } - - @Nonnull - static void readStates(@Nonnull final DataInput in, @Nonnull final byte[] status) - throws IOException { - // read non-empty states differentially - final int cardinarity = in.readInt(); - Arrays.fill(status, IntOpenHashTable.FULL); - int prev = 0; - for (int j = 0; j < cardinarity; j++) { - int i = VariableByteCodec.decodeUnsignedInt(in) + prev; - status[i] = IntOpenHashTable.FREE; - prev = i; - } - } - - public byte[] serialize() throws IOException { - LOG.info("FFMPredictionModel#serialize(): " + _buf.toString()); - return ObjectUtils.toCompressedBytes(this, CompressionAlgorithm.lzma2, true); - } - - public static FFMPredictionModel deserialize(@Nonnull final byte[] serializedObj, final int len) - throws ClassNotFoundException, IOException { - FFMPredictionModel model = new FFMPredictionModel(); - ObjectUtils.readCompressedObject(serializedObj, len, model, CompressionAlgorithm.lzma2, - true); - LOG.info("FFMPredictionModel#deserialize(): " + model._buf.toString()); - return model; - } - -} diff --git a/core/src/main/java/hivemall/fm/FFMStringFeatureMapModel.java b/core/src/main/java/hivemall/fm/FFMStringFeatureMapModel.java index 4f445fa68..22b05418a 100644 --- a/core/src/main/java/hivemall/fm/FFMStringFeatureMapModel.java +++ b/core/src/main/java/hivemall/fm/FFMStringFeatureMapModel.java @@ -22,13 +22,20 @@ import hivemall.fm.Entry.FTRLEntry; import hivemall.fm.FMHyperParameters.FFMHyperParameters; import hivemall.utils.buffer.HeapBuffer; +import hivemall.utils.collections.lists.LongArrayList; import hivemall.utils.collections.maps.Int2LongOpenHashTable; +import hivemall.utils.collections.maps.Int2LongOpenHashTable.MapIterator; import hivemall.utils.lang.NumberUtils; -import hivemall.utils.math.MathUtils; +import java.text.NumberFormat; +import java.util.Locale; + +import javax.annotation.Nonnegative; import javax.annotation.Nonnull; import javax.annotation.Nullable; +import org.roaringbitmap.RoaringBitmap; + public final class FFMStringFeatureMapModel extends FieldAwareFactorizationMachineModel { private static final int DEFAULT_MAPSIZE = 65536; @@ -36,37 +43,55 @@ public final class FFMStringFeatureMapModel extends FieldAwareFactorizationMachi private float _w0; @Nonnull private final Int2LongOpenHashTable _map; + @Nonnull private final HeapBuffer _buf; + @Nonnull + private final LongArrayList _freelistW; + @Nonnull + private final LongArrayList _freelistV; + + private boolean _initV; + @Nonnull + private RoaringBitmap _removedV; + // hyperparams - private final int _numFeatures; private final int _numFields; - // FTEL - private final float _alpha; - private final float _beta; - private final float _lambda1; - private final float _lamdda2; + private final int _entrySizeW; + private final int _entrySizeV; - private final int _entrySize; + // statistics + private long _bytesAllocated, _bytesUsed; + private int _numAllocatedW, _numReusedW, _numRemovedW; + private 
int _numAllocatedV, _numReusedV, _numRemovedV; public FFMStringFeatureMapModel(@Nonnull FFMHyperParameters params) { super(params); this._w0 = 0.f; this._map = new Int2LongOpenHashTable(DEFAULT_MAPSIZE); this._buf = new HeapBuffer(HeapBuffer.DEFAULT_CHUNK_SIZE); - this._numFeatures = params.numFeatures; + this._freelistW = new LongArrayList(); + this._freelistV = new LongArrayList(); + this._initV = true; + this._removedV = new RoaringBitmap(); this._numFields = params.numFields; - this._alpha = params.alphaFTRL; - this._beta = params.betaFTRL; - this._lambda1 = params.lambda1; - this._lamdda2 = params.lamdda2; - this._entrySize = entrySize(_factor, _useFTRL, _useAdaGrad); + this._entrySizeW = entrySize(1, _useFTRL, _useAdaGrad); + this._entrySizeV = entrySize(_factor, _useFTRL, _useAdaGrad); } - @Nonnull - FFMPredictionModel toPredictionModel() { - return new FFMPredictionModel(_map, _buf, _w0, _factor, _numFeatures, _numFields); + private static int entrySize(@Nonnegative int factors, boolean ftrl, boolean adagrad) { + if (ftrl) { + return FTRLEntry.sizeOf(factors); + } else if (adagrad) { + return AdaGradEntry.sizeOf(factors); + } else { + return Entry.sizeOf(factors); + } + } + + void disableInitV() { + this._initV = false; } @Override @@ -86,7 +111,7 @@ protected void setW0(float nextW0) { @Override public float getW(@Nonnull final Feature x) { - int j = x.getFeatureIndex(); + int j = Feature.toIntFeature(x); Entry entry = getEntry(j); if (entry == null) { @@ -97,12 +122,11 @@ public float getW(@Nonnull final Feature x) { @Override protected void setW(@Nonnull final Feature x, final float nextWi) { - final int j = x.getFeatureIndex(); + final int j = Feature.toIntFeature(x); Entry entry = getEntry(j); if (entry == null) { - float[] V = initV(); - entry = newEntry(nextWi, V); + entry = newEntry(j, nextWi); long ptr = entry.getOffset(); _map.put(j, ptr); } else { @@ -110,53 +134,6 @@ protected void setW(@Nonnull final Feature x, final float nextWi) { } } - @Override - void updateWi(final double dloss, @Nonnull final Feature x, final float eta) { - final double Xi = x.getValue(); - float gradWi = (float) (dloss * Xi); - - final Entry theta = getEntry(x); - float wi = theta.getW(); - - float nextWi = wi - eta * (gradWi + 2.f * _lambdaW * wi); - if (!NumberUtils.isFinite(nextWi)) { - throw new IllegalStateException("Got " + nextWi + " for next W[" + x.getFeature() - + "]\n" + "Xi=" + Xi + ", gradWi=" + gradWi + ", wi=" + wi + ", dloss=" + dloss - + ", eta=" + eta); - } - theta.setW(nextWi); - } - - /** - * Update Wi using Follow-the-Regularized-Leader - */ - boolean updateWiFTRL(final double dloss, @Nonnull final Feature x, final float eta) { - final double Xi = x.getValue(); - float gradWi = (float) (dloss * Xi); - - final Entry theta = getEntry(x); - float wi = theta.getW(); - - final float z = theta.updateZ(gradWi, _alpha); - final double n = theta.updateN(gradWi); - - if (Math.abs(z) <= _lambda1) { - removeEntry(x); - return wi != 0; - } - - final float nextWi = (float) ((MathUtils.sign(z) * _lambda1 - z) / ((_beta + Math.sqrt(n)) - / _alpha + _lamdda2)); - if (!NumberUtils.isFinite(nextWi)) { - throw new IllegalStateException("Got " + nextWi + " for next W[" + x.getFeature() - + "]\n" + "Xi=" + Xi + ", gradWi=" + gradWi + ", wi=" + wi + ", dloss=" + dloss - + ", eta=" + eta + ", n=" + n + ", z=" + z); - } - theta.setW(nextWi); - return (nextWi != 0) || (wi != 0); - } - - /** * @return V_x,yField,f */ @@ -166,10 +143,16 @@ public float getV(@Nonnull final Feature x, @Nonnull final int 
yField, final int Entry entry = getEntry(j); if (entry == null) { + if (_initV == false) { + return 0.f; + } else if (_removedV.contains(j)) { + return 0.f; + } float[] V = initV(); - entry = newEntry(V); + entry = newEntry(j, V); long ptr = entry.getOffset(); _map.put(j, ptr); + return V[f]; } return entry.getV(f); } @@ -181,8 +164,13 @@ protected void setV(@Nonnull final Feature x, @Nonnull final int yField, final i Entry entry = getEntry(j); if (entry == null) { + if (_initV == false) { + return; + } else if (_removedV.contains(j)) { + return; + } float[] V = initV(); - entry = newEntry(V); + entry = newEntry(j, V); long ptr = entry.getOffset(); _map.put(j, ptr); } @@ -190,13 +178,12 @@ protected void setV(@Nonnull final Feature x, @Nonnull final int yField, final i } @Override - protected Entry getEntry(@Nonnull final Feature x) { - final int j = x.getFeatureIndex(); + protected Entry getEntryW(@Nonnull final Feature x) { + final int j = Feature.toIntFeature(x); Entry entry = getEntry(j); if (entry == null) { - float[] V = initV(); - entry = newEntry(V); + entry = newEntry(j, 0.f); long ptr = entry.getOffset(); _map.put(j, ptr); } @@ -204,51 +191,92 @@ protected Entry getEntry(@Nonnull final Feature x) { } @Override - protected Entry getEntry(@Nonnull final Feature x, @Nonnull final int yField) { + protected Entry getEntryV(@Nonnull final Feature x, @Nonnull final int yField) { final int j = Feature.toIntFeature(x, yField, _numFields); Entry entry = getEntry(j); if (entry == null) { + if (_initV == false) { + return null; + } else if (_removedV.contains(j)) { + return null; + } float[] V = initV(); - entry = newEntry(V); + entry = newEntry(j, V); long ptr = entry.getOffset(); _map.put(j, ptr); } return entry; } - protected void removeEntry(@Nonnull final Feature x) { - int j = x.getFeatureIndex(); - _map.remove(j); + @Override + protected void removeEntry(@Nonnull final Entry entry) { + final int j = entry.getKey(); + final long ptr = _map.remove(j); + if (ptr == -1L) { + return; // should never happen. 
+ } + entry.clear(); + if (Entry.isEntryW(j)) { + _freelistW.add(ptr); + this._numRemovedW++; + this._bytesUsed -= _entrySizeW; + } else { + _removedV.add(j); + _freelistV.add(ptr); + this._numRemovedV++; + this._bytesUsed -= _entrySizeV; + } } @Nonnull - protected final Entry newEntry(final float W, @Nonnull final float[] V) { - Entry entry = newEntry(); - entry.setW(W); - entry.setV(V); - return entry; - } + protected final Entry newEntry(final int key, final float W) { + final long ptr; + if (_freelistW.isEmpty()) { + ptr = _buf.allocate(_entrySizeW); + this._numAllocatedW++; + this._bytesAllocated += _entrySizeW; + this._bytesUsed += _entrySizeW; + } else {// reuse removed entry + ptr = _freelistW.remove(); + this._numReusedW++; + } + final Entry entry; + if (_useFTRL) { + entry = new FTRLEntry(_buf, key, ptr); + } else if (_useAdaGrad) { + entry = new AdaGradEntry(_buf, key, ptr); + } else { + entry = new Entry(_buf, key, ptr); + } - @Nonnull - protected final Entry newEntry(@Nonnull final float[] V) { - Entry entry = newEntry(); - entry.setV(V); + entry.setW(W); return entry; } @Nonnull - private Entry newEntry() { + protected final Entry newEntry(final int key, @Nonnull final float[] V) { + final long ptr; + if (_freelistV.isEmpty()) { + ptr = _buf.allocate(_entrySizeV); + this._numAllocatedV++; + this._bytesAllocated += _entrySizeV; + this._bytesUsed += _entrySizeV; + } else {// reuse removed entry + ptr = _freelistV.remove(); + this._numReusedV++; + } + final Entry entry; if (_useFTRL) { - long ptr = _buf.allocate(_entrySize); - return new FTRLEntry(_buf, _factor, ptr); + entry = new FTRLEntry(_buf, _factor, key, ptr); } else if (_useAdaGrad) { - long ptr = _buf.allocate(_entrySize); - return new AdaGradEntry(_buf, _factor, ptr); + entry = new AdaGradEntry(_buf, _factor, key, ptr); } else { - long ptr = _buf.allocate(_entrySize); - return new Entry(_buf, _factor, ptr); + entry = new Entry(_buf, _factor, key, ptr); } + + entry.setV(V); + return entry; } @Nullable @@ -257,28 +285,95 @@ private Entry getEntry(final int key) { if (ptr == -1L) { return null; } - return getEntry(ptr); + return getEntry(key, ptr); } @Nonnull - private Entry getEntry(long ptr) { - if (_useFTRL) { - return new FTRLEntry(_buf, _factor, ptr); - } else if (_useAdaGrad) { - return new AdaGradEntry(_buf, _factor, ptr); + private Entry getEntry(final int key, @Nonnegative final long ptr) { + if (Entry.isEntryW(key)) { + if (_useFTRL) { + return new FTRLEntry(_buf, key, ptr); + } else if (_useAdaGrad) { + return new AdaGradEntry(_buf, key, ptr); + } else { + return new Entry(_buf, key, ptr); + } } else { - return new Entry(_buf, _factor, ptr); + if (_useFTRL) { + return new FTRLEntry(_buf, _factor, key, ptr); + } else if (_useAdaGrad) { + return new AdaGradEntry(_buf, _factor, key, ptr); + } else { + return new Entry(_buf, _factor, key, ptr); + } } } - private static int entrySize(int factors, boolean ftrl, boolean adagrad) { - if (ftrl) { - return FTRLEntry.sizeOf(factors); - } else if (adagrad) { - return AdaGradEntry.sizeOf(factors); - } else { - return Entry.sizeOf(factors); + @Nonnull + String getStatistics() { + final NumberFormat fmt = NumberFormat.getIntegerInstance(Locale.US); + return "FFMStringFeatureMapModel [bytesAllocated=" + + NumberUtils.prettySize(_bytesAllocated) + ", bytesUsed=" + + NumberUtils.prettySize(_bytesUsed) + ", numAllocatedW=" + + fmt.format(_numAllocatedW) + ", numReusedW=" + fmt.format(_numReusedW) + + ", numRemovedW=" + fmt.format(_numRemovedW) + ", numAllocatedV=" + + 
fmt.format(_numAllocatedV) + ", numReusedV=" + fmt.format(_numReusedV) + + ", numRemovedV=" + fmt.format(_numRemovedV) + "]"; + } + + @Override + public String toString() { + return getStatistics(); + } + + @Nonnull + EntryIterator entries() { + return new EntryIterator(this); + } + + static final class EntryIterator { + + @Nonnull + private final MapIterator dictItor; + @Nonnull + private final Entry entryProbeW; + @Nonnull + private final Entry entryProbeV; + + EntryIterator(@Nonnull FFMStringFeatureMapModel model) { + this.dictItor = model._map.entries(); + this.entryProbeW = new Entry(model._buf, 1); + this.entryProbeV = new Entry(model._buf, model._factor); + } + + @Nonnull + Entry getEntryProbeW() { + return entryProbeW; } + + @Nonnull + Entry getEntryProbeV() { + return entryProbeV; + } + + boolean hasNext() { + return dictItor.hasNext(); + } + + boolean next() { + return dictItor.next() != -1; + } + + int getEntryIndex() { + return dictItor.getKey(); + } + + @Nonnull + void getEntry(@Nonnull final Entry probe) { + long offset = dictItor.getValue(); + probe.setOffset(offset); + } + } } diff --git a/core/src/main/java/hivemall/fm/FMHyperParameters.java b/core/src/main/java/hivemall/fm/FMHyperParameters.java index accb99a22..15c1c5662 100644 --- a/core/src/main/java/hivemall/fm/FMHyperParameters.java +++ b/core/src/main/java/hivemall/fm/FMHyperParameters.java @@ -143,16 +143,15 @@ public static final class FFMHyperParameters extends FMHyperParameters { int numFields = Feature.DEFAULT_NUM_FIELDS; // adagrad - boolean useAdaGrad = true; - float eta0_V = 1.f; + boolean useAdaGrad = false; float eps = 1.f; // FTRL - boolean useFTRL = true; - float alphaFTRL = 0.1f; // Learning Rate + boolean useFTRL = false; + float alphaFTRL = 0.2f; // Learning Rate float betaFTRL = 1.f; // Smoothing parameter for AdaGrad - float lambda1 = 0.1f; // L1 Regularization - float lamdda2 = 0.01f; // L2 Regularization + float lambda1 = 0.001f; // L1 Regularization + float lamdda2 = 0.0001f; // L2 Regularization FFMHyperParameters() { super(); @@ -171,42 +170,59 @@ void processOptions(@Nonnull CommandLine cl) throws UDFArgumentException { // feature hashing if (numFeatures == -1) { - int hashbits = Primitives.parseInt(cl.getOptionValue("feature_hashing"), - Feature.DEFAULT_FEATURE_BITS); - if (hashbits < 18 || hashbits > 31) { - throw new UDFArgumentException("-feature_hashing MUST be in range [18,31]: " - + hashbits); + int hashbits = Primitives.parseInt(cl.getOptionValue("feature_hashing"), -1); + if (hashbits != -1) { + if (hashbits < 18 || hashbits > 31) { + throw new UDFArgumentException( + "-feature_hashing MUST be in range [18,31]: " + hashbits); + } + this.numFeatures = 1 << hashbits; } - this.numFeatures = 1 << hashbits; } this.numFields = Primitives.parseInt(cl.getOptionValue("num_fields"), numFields); if (numFields <= 1) { throw new UDFArgumentException("-num_fields MUST be greater than 1: " + numFields); } - // adagrad - this.useAdaGrad = !cl.hasOption("disable_adagrad"); - this.eta0_V = Primitives.parseFloat(cl.getOptionValue("eta0_V"), eta0_V); - this.eps = Primitives.parseFloat(cl.getOptionValue("eps"), eps); - - // FTRL - this.useFTRL = !cl.hasOption("disable_ftrl"); - this.alphaFTRL = Primitives.parseFloat(cl.getOptionValue("alphaFTRL"), alphaFTRL); - if (alphaFTRL == 0.f) { - throw new UDFArgumentException("-alphaFTRL SHOULD NOT be 0"); + // optimizer + final String optimizer = cl.getOptionValue("optimizer", "ftrl").toLowerCase(); + switch (optimizer) { + case "ftrl": { + this.useFTRL = true; + 
this.useAdaGrad = false; + this.alphaFTRL = Primitives.parseFloat(cl.getOptionValue("alphaFTRL"), + alphaFTRL); + if (alphaFTRL == 0.f) { + throw new UDFArgumentException("-alphaFTRL SHOULD NOT be 0"); + } + this.betaFTRL = Primitives.parseFloat(cl.getOptionValue("betaFTRL"), betaFTRL); + this.lambda1 = Primitives.parseFloat(cl.getOptionValue("lambda1"), lambda1); + this.lamdda2 = Primitives.parseFloat(cl.getOptionValue("lamdda2"), lamdda2); + break; + } + case "adagrad": { + this.useAdaGrad = true; + this.useFTRL = false; + this.eps = Primitives.parseFloat(cl.getOptionValue("eps"), eps); + break; + } + case "sgd": + // fall through + default: { + this.useFTRL = false; + this.useAdaGrad = false; + break; + } } - this.betaFTRL = Primitives.parseFloat(cl.getOptionValue("betaFTRL"), betaFTRL); - this.lambda1 = Primitives.parseFloat(cl.getOptionValue("lambda1"), lambda1); - this.lamdda2 = Primitives.parseFloat(cl.getOptionValue("lamdda2"), lamdda2); } @Override public String toString() { return "FFMHyperParameters [globalBias=" + globalBias + ", linearCoeff=" + linearCoeff - + ", numFields=" + numFields + ", useAdaGrad=" + useAdaGrad + ", eta0_V=" - + eta0_V + ", eps=" + eps + ", useFTRL=" + useFTRL + ", alphaFTRL=" + alphaFTRL - + ", betaFTRL=" + betaFTRL + ", lambda1=" + lambda1 + ", lamdda2=" + lamdda2 - + "], " + super.toString(); + + ", numFields=" + numFields + ", useAdaGrad=" + useAdaGrad + ", eps=" + eps + + ", useFTRL=" + useFTRL + ", alphaFTRL=" + alphaFTRL + ", betaFTRL=" + + betaFTRL + ", lambda1=" + lambda1 + ", lamdda2=" + lamdda2 + "], " + + super.toString(); } } diff --git a/core/src/main/java/hivemall/fm/FMIntFeatureMapModel.java b/core/src/main/java/hivemall/fm/FMIntFeatureMapModel.java index 19ac287da..be39b0bbe 100644 --- a/core/src/main/java/hivemall/fm/FMIntFeatureMapModel.java +++ b/core/src/main/java/hivemall/fm/FMIntFeatureMapModel.java @@ -19,7 +19,7 @@ package hivemall.fm; import hivemall.utils.collections.maps.Int2FloatOpenHashTable; -import hivemall.utils.collections.maps.IntOpenHashMap; +import hivemall.utils.collections.maps.IntOpenHashTable; import java.util.Arrays; @@ -33,7 +33,7 @@ public final class FMIntFeatureMapModel extends FactorizationMachineModel { // LEARNING PARAMS private float _w0; private final Int2FloatOpenHashTable _w; - private final IntOpenHashMap _V; + private final IntOpenHashTable _V; private int _minIndex, _maxIndex; @@ -42,7 +42,7 @@ public FMIntFeatureMapModel(@Nonnull FMHyperParameters params) { this._w0 = 0.f; this._w = new Int2FloatOpenHashTable(DEFAULT_MAPSIZE); _w.defaultReturnValue(0.f); - this._V = new IntOpenHashMap(DEFAULT_MAPSIZE); + this._V = new IntOpenHashTable(DEFAULT_MAPSIZE); this._minIndex = 0; this._maxIndex = 0; } diff --git a/core/src/main/java/hivemall/fm/FMPredictGenericUDAF.java b/core/src/main/java/hivemall/fm/FMPredictGenericUDAF.java index 667befb8c..730cc49d2 100644 --- a/core/src/main/java/hivemall/fm/FMPredictGenericUDAF.java +++ b/core/src/main/java/hivemall/fm/FMPredictGenericUDAF.java @@ -18,6 +18,9 @@ */ package hivemall.fm; +import static org.apache.hadoop.hive.ql.util.JavaDataModel.JAVA64_ARRAY_META; +import static org.apache.hadoop.hive.ql.util.JavaDataModel.JAVA64_REF; +import static org.apache.hadoop.hive.ql.util.JavaDataModel.PRIMITIVES2; import hivemall.utils.hadoop.HiveUtils; import hivemall.utils.hadoop.WritableUtils; @@ -35,6 +38,7 @@ import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver; import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator; import 
org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AbstractAggregationBuffer; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationType; import org.apache.hadoop.hive.serde2.io.DoubleWritable; import org.apache.hadoop.hive.serde2.lazybinary.LazyBinaryArray; import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector; @@ -234,6 +238,7 @@ public DoubleWritable terminate(@SuppressWarnings("deprecation") AggregationBuff } + @AggregationType(estimable = true) public static class FMPredictAggregationBuffer extends AbstractAggregationBuffer { private double ret; @@ -328,6 +333,16 @@ void merge(final double o_ret, @Nullable final Object o_sumVjXj, } return predict; } + + @Override + public int estimate() { + if (sumVjXj == null) { + return PRIMITIVES2 + 2 * JAVA64_REF; + } else { + // model.array() = JAVA64_ARRAY_META + JAVA64_REF + return PRIMITIVES2 + 2 * (JAVA64_ARRAY_META + PRIMITIVES2 * sumVjXj.length); + } + } } } diff --git a/core/src/main/java/hivemall/fm/FMStringFeatureMapModel.java b/core/src/main/java/hivemall/fm/FMStringFeatureMapModel.java index cd99046d3..4eec280a5 100644 --- a/core/src/main/java/hivemall/fm/FMStringFeatureMapModel.java +++ b/core/src/main/java/hivemall/fm/FMStringFeatureMapModel.java @@ -19,7 +19,7 @@ package hivemall.fm; import hivemall.utils.collections.IMapIterator; -import hivemall.utils.collections.maps.OpenHashTable; +import hivemall.utils.collections.maps.OpenHashMap; import javax.annotation.Nonnull; @@ -28,12 +28,12 @@ public final class FMStringFeatureMapModel extends FactorizationMachineModel { // LEARNING PARAMS private float _w0; - private final OpenHashTable _map; + private final OpenHashMap _map; public FMStringFeatureMapModel(@Nonnull FMHyperParameters params) { super(params); this._w0 = 0.f; - this._map = new OpenHashTable(DEFAULT_MAPSIZE); + this._map = new OpenHashMap(DEFAULT_MAPSIZE); } @Override @@ -42,7 +42,7 @@ public int getSize() { } IMapIterator entries() { - return _map.entries(); + return _map.entries(true); } @Override diff --git a/core/src/main/java/hivemall/fm/FactorizationMachineUDTF.java b/core/src/main/java/hivemall/fm/FactorizationMachineUDTF.java index 65b6ba717..24210a889 100644 --- a/core/src/main/java/hivemall/fm/FactorizationMachineUDTF.java +++ b/core/src/main/java/hivemall/fm/FactorizationMachineUDTF.java @@ -117,8 +117,8 @@ protected Options getOptions() { opts.addOption("c", "classification", false, "Act as classification"); opts.addOption("seed", true, "Seed value [default: -1 (random)]"); opts.addOption("iters", "iterations", true, "The number of iterations [default: 1]"); - opts.addOption("p", "num_features", true, "The size of feature dimensions"); - opts.addOption("factor", "factors", true, "The number of the latent variables [default: 5]"); + opts.addOption("p", "num_features", true, "The size of feature dimensions [default: -1]"); + opts.addOption("f", "factors", true, "The number of the latent variables [default: 5]"); opts.addOption("sigma", true, "The standard deviation for initializing V [default: 0.1]"); opts.addOption("lambda0", "lambda", true, "The initial lambda value for regularization [default: 0.01]"); @@ -376,7 +376,7 @@ protected void trainTheta(final Feature[] x, final double y) throws HiveExceptio double loss = _lossFunction.loss(p, y); _cvState.incrLoss(loss); - if (MathUtils.closeToZero(lossGrad)) { + if (MathUtils.closeToZero(lossGrad, 1E-9d)) { return; } diff --git a/core/src/main/java/hivemall/fm/Feature.java 
b/core/src/main/java/hivemall/fm/Feature.java index 2966a0214..8ae6f203f 100644 --- a/core/src/main/java/hivemall/fm/Feature.java +++ b/core/src/main/java/hivemall/fm/Feature.java @@ -23,6 +23,7 @@ import java.nio.ByteBuffer; +import javax.annotation.Nonnegative; import javax.annotation.Nonnull; import javax.annotation.Nullable; @@ -30,7 +31,7 @@ import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector; public abstract class Feature { - public static final int DEFAULT_NUM_FIELDS = 1024; + public static final int DEFAULT_NUM_FIELDS = 256; public static final int DEFAULT_FEATURE_BITS = 21; public static final int DEFAULT_NUM_FEATURES = 1 << 21; // 2^21 @@ -51,10 +52,11 @@ public String getFeature() { throw new UnsupportedOperationException(); } - public void setFeatureIndex(int i) { + public void setFeatureIndex(@Nonnegative int i) { throw new UnsupportedOperationException(); } + @Nonnegative public int getFeatureIndex() { throw new UnsupportedOperationException(); } @@ -127,6 +129,7 @@ public static Feature[] parseFeatures(@Nonnull final Object arg, } } + @Nullable public static Feature[] parseFFMFeatures(@Nonnull final Object arg, @Nonnull final ListObjectInspector listOI, @Nullable final Feature[] probes, final int numFeatures, final int numFields) throws HiveException { @@ -176,6 +179,9 @@ static Feature parseFeature(@Nonnull final String fv, final boolean asIntFeature int index = parseFeatureIndex(fv); return new IntFeature(index, 1.d); } else { + if ("0".equals(fv)) { + throw new HiveException("Index value should not be 0: " + fv); + } return new StringFeature(/* index */fv, 1.d); } } else { @@ -187,6 +193,9 @@ static Feature parseFeature(@Nonnull final String fv, final boolean asIntFeature return new IntFeature(index, value); } else { double value = parseFeatureValue(valueStr); + if ("0".equals(indexStr)) { + throw new HiveException("Index value should not be 0: " + fv); + } return new StringFeature(/* index */indexStr, value); } } @@ -197,6 +206,12 @@ static IntFeature parseFFMFeature(@Nonnull final String fv) throws HiveException return parseFFMFeature(fv, DEFAULT_NUM_FEATURES, DEFAULT_NUM_FIELDS); } + @Nonnull + static IntFeature parseFFMFeature(@Nonnull final String fv, final int numFeatures) + throws HiveException { + return parseFFMFeature(fv, -1, DEFAULT_NUM_FIELDS); + } + @Nonnull static IntFeature parseFFMFeature(@Nonnull final String fv, final int numFeatures, final int numFields) throws HiveException { @@ -219,25 +234,26 @@ static IntFeature parseFFMFeature(@Nonnull final String fv, final int numFeature } else { index = MurmurHash3.murmurhash3(lead, numFields); } - short field = (short) index; + short field = NumberUtils.castToShort(index); double value = parseFeatureValue(rest); return new IntFeature(index, field, value); } - final String indexStr = rest.substring(0, pos2); - final int index; + final short field; - if (NumberUtils.isDigits(indexStr) && NumberUtils.isDigits(lead)) { - index = parseFeatureIndex(indexStr); - if (index >= (numFeatures + numFields)) { - throw new HiveException("Feature index MUST be less than " - + (numFeatures + numFields) + " but was " + index); - } + if (NumberUtils.isDigits(lead)) { field = parseField(lead, numFields); + } else { + field = NumberUtils.castToShort(MurmurHash3.murmurhash3(lead, numFields)); + } + + final int index; + final String indexStr = rest.substring(0, pos2); + if (numFeatures == -1 && NumberUtils.isDigits(indexStr)) { + index = parseFeatureIndex(indexStr); } else { // +NUM_FIELD to avoid conflict to 
quantitative features index = MurmurHash3.murmurhash3(indexStr, numFeatures) + numFields; - field = (short) MurmurHash3.murmurhash3(lead, numFields); } String valueStr = rest.substring(pos2 + 1); double value = parseFeatureValue(valueStr); @@ -253,6 +269,9 @@ static void parseFeature(@Nonnull final String fv, @Nonnull final Feature probe, int index = parseFeatureIndex(fv); probe.setFeatureIndex(index); } else { + if ("0".equals(fv)) { + throw new HiveException("Index value should not be 0: " + fv); + } probe.setFeature(fv); } probe.value = 1.d; @@ -264,6 +283,9 @@ static void parseFeature(@Nonnull final String fv, @Nonnull final Feature probe, probe.setFeatureIndex(index); probe.value = parseFeatureValue(valueStr); } else { + if ("0".equals(indexStr)) { + throw new HiveException("Index value should not be 0: " + fv); + } probe.setFeature(indexStr); probe.value = parseFeatureValue(valueStr); } @@ -296,27 +318,26 @@ static void parseFFMFeature(@Nonnull final String fv, @Nonnull final Feature pro } else { index = MurmurHash3.murmurhash3(lead, numFields); } - short field = (short) index; + short field = NumberUtils.castToShort(index); probe.setField(field); probe.setFeatureIndex(index); probe.value = parseFeatureValue(rest); return; } - String indexStr = rest.substring(0, pos2); - final int index; final short field; - if (NumberUtils.isDigits(indexStr) && NumberUtils.isDigits(lead)) { - index = parseFeatureIndex(indexStr); - if (index >= (numFeatures + numFields)) { - throw new HiveException("Feature index MUST be less than " - + (numFeatures + numFields) + " but was " + index); - } + if (NumberUtils.isDigits(lead)) { field = parseField(lead, numFields); + } else { + field = NumberUtils.castToShort(MurmurHash3.murmurhash3(lead, numFields)); + } + final int index; + final String indexStr = rest.substring(0, pos2); + if (numFeatures == -1 && NumberUtils.isDigits(indexStr)) { + index = parseFeatureIndex(indexStr); } else { // +NUM_FIELD to avoid conflict to quantitative features index = MurmurHash3.murmurhash3(indexStr, numFeatures) + numFields; - field = (short) MurmurHash3.murmurhash3(lead, numFields); } probe.setField(field); probe.setFeatureIndex(index); @@ -325,7 +346,6 @@ static void parseFFMFeature(@Nonnull final String fv, @Nonnull final Feature pro probe.value = parseFeatureValue(valueStr); } - private static int parseFeatureIndex(@Nonnull final String indexStr) throws HiveException { final int index; try { @@ -333,7 +353,7 @@ private static int parseFeatureIndex(@Nonnull final String indexStr) throws Hive } catch (NumberFormatException e) { throw new HiveException("Invalid index value: " + indexStr, e); } - if (index < 0) { + if (index <= 0) { throw new HiveException("Feature index MUST be greater than 0: " + indexStr); } return index; @@ -361,7 +381,13 @@ private static short parseField(@Nonnull final String fieldStr, final int numFie return field; } - public static int toIntFeature(@Nonnull final Feature x, final int yField, final int numFields) { + public static int toIntFeature(@Nonnull final Feature x) { + int index = x.getFeatureIndex(); + return -index; + } + + public static int toIntFeature(@Nonnull final Feature x, @Nonnegative final int yField, + @Nonnegative final int numFields) { int index = x.getFeatureIndex(); return index * numFields + yField; } diff --git a/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineModel.java b/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineModel.java index 76bead8f0..730d0f402 100644 --- 
a/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineModel.java
+++ b/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineModel.java
@@ -22,9 +22,11 @@
 import hivemall.utils.collections.arrays.DoubleArray3D;
 import hivemall.utils.collections.lists.IntArrayList;
 import hivemall.utils.lang.NumberUtils;
+import hivemall.utils.math.MathUtils;

 import java.util.Arrays;

+import javax.annotation.Nonnegative;
 import javax.annotation.Nonnull;
 import javax.annotation.Nullable;

@@ -34,19 +36,33 @@ public abstract class FieldAwareFactorizationMachineModel extends FactorizationM
     @Nonnull
     protected final FFMHyperParameters _params;
-    protected final float _eta0_V;
+    protected final float _eta0;
     protected final float _eps;

     protected final boolean _useAdaGrad;
     protected final boolean _useFTRL;

+    // FTRL
+    private final float _alpha;
+    private final float _beta;
+    private final float _lambda1;
+    private final float _lamdda2;
+
     public FieldAwareFactorizationMachineModel(@Nonnull FFMHyperParameters params) {
         super(params);
         this._params = params;
-        this._eta0_V = params.eta0_V;
+        if (params.useAdaGrad) {
+            this._eta0 = 1.0f;
+        } else {
+            this._eta0 = params.eta.eta0();
+        }
         this._eps = params.eps;
         this._useAdaGrad = params.useAdaGrad;
         this._useFTRL = params.useFTRL;
+        this._alpha = params.alphaFTRL;
+        this._beta = params.betaFTRL;
+        this._lambda1 = params.lambda1;
+        this._lamdda2 = params.lamdda2;
     }

     public abstract float getV(@Nonnull Feature x, @Nonnull int yField, int f);
@@ -100,31 +116,152 @@ protected final double predict(@Nonnull final Feature[] x) throws HiveException
         return ret;
     }

+    void updateWi(final double dloss, @Nonnull final Feature x, final long t) {
+        if (_useFTRL) {
+            updateWi_FTRL(dloss, x);
+            return;
+        }
+
+        final double Xi = x.getValue();
+        float gradWi = (float) (dloss * Xi);
+
+        final Entry theta = getEntryW(x);
+        float wi = theta.getW();
+
+        final float eta = eta(theta, t, gradWi);
+        float nextWi = wi - eta * (gradWi + 2.f * _lambdaW * wi);
+        if (!NumberUtils.isFinite(nextWi)) {
+            throw new IllegalStateException("Got " + nextWi + " for next W[" + x.getFeature()
+                    + "]\n" + "Xi=" + Xi + ", gradWi=" + gradWi + ", wi=" + wi + ", dloss=" + dloss
+                    + ", eta=" + eta + ", t=" + t);
+        }
+        if (MathUtils.closeToZero(nextWi, 1E-9f)) {
+            removeEntry(theta);
+            return;
+        }
+        theta.setW(nextWi);
+    }
+
+    /**
+     * Update Wi using Follow-the-Regularized-Leader
+     */
+    private void updateWi_FTRL(final double dloss, @Nonnull final Feature x) {
+        final double Xi = x.getValue();
+        float gradWi = (float) (dloss * Xi);
+
+        final Entry theta = getEntryW(x);
+
+        final float z = theta.updateZ(gradWi, _alpha);
+        final double n = theta.updateN(gradWi);
+
+        if (Math.abs(z) <= _lambda1) {
+            removeEntry(theta);
+            return;
+        }
+
+        final float nextWi = (float) ((MathUtils.sign(z) * _lambda1 - z) / ((_beta + Math.sqrt(n))
+                / _alpha + _lamdda2));
+        if (!NumberUtils.isFinite(nextWi)) {
+            throw new IllegalStateException("Got " + nextWi + " for next W[" + x.getFeature()
+                    + "]\n" + "Xi=" + Xi + ", gradWi=" + gradWi + ", wi=" + theta.getW()
+                    + ", dloss=" + dloss + ", n=" + n + ", z=" + z);
+        }
+        if (MathUtils.closeToZero(nextWi, 1E-9f)) {
+            removeEntry(theta);
+            return;
+        }
+        theta.setW(nextWi);
+    }
+
+    protected abstract void removeEntry(@Nonnull final Entry entry);
+
     void updateV(final double dloss, @Nonnull final Feature x, @Nonnull final int yField,
             final int f, final double sumViX, long t) {
+        if (_useFTRL) {
+            updateV_FTRL(dloss, x, yField, f, sumViX);
+            return;
+        }
+
+        final Entry theta = getEntryV(x, 
yField); + if (theta == null) { + return; + } + final double Xi = x.getValue(); final double h = Xi * sumViX; final float gradV = (float) (dloss * h); final float lambdaVf = getLambdaV(f); - final Entry theta = getEntry(x, yField); final float currentV = theta.getV(f); - final float eta = etaV(theta, t, gradV); + final float eta = eta(theta, f, t, gradV); final float nextV = currentV - eta * (gradV + 2.f * lambdaVf * currentV); if (!NumberUtils.isFinite(nextV)) { throw new IllegalStateException("Got " + nextV + " for next V" + f + '[' + x.getFeatureIndex() + "]\n" + "Xi=" + Xi + ", Vif=" + currentV + ", h=" + h + ", gradV=" + gradV + ", lambdaVf=" + lambdaVf + ", dloss=" + dloss - + ", sumViX=" + sumViX); + + ", sumViX=" + sumViX + ", t=" + t); + } + if (MathUtils.closeToZero(nextV, 1E-9f)) { + theta.setV(f, 0.f); + if (theta.removable()) { // Whether other factors are zero filled or not? Remove if zero filled + removeEntry(theta); + } + return; + } + theta.setV(f, nextV); + } + + private void updateV_FTRL(final double dloss, @Nonnull final Feature x, + @Nonnull final int yField, final int f, final double sumViX) { + final Entry theta = getEntryV(x, yField); + if (theta == null) { + return; + } + + final double Xi = x.getValue(); + final double h = Xi * sumViX; + final float gradV = (float) (dloss * h); + + float oldV = theta.getV(f); + final float z = theta.updateZ(f, oldV, gradV, _alpha); + final double n = theta.updateN(f, gradV); + + if (Math.abs(z) <= _lambda1) { + theta.setV(f, 0.f); + if (theta.removable()) { // Whether other factors are zero filled or not? Remove if zero filled + removeEntry(theta); + } + return; + } + + final float nextV = (float) ((MathUtils.sign(z) * _lambda1 - z) / ((_beta + Math.sqrt(n)) + / _alpha + _lamdda2)); + if (!NumberUtils.isFinite(nextV)) { + throw new IllegalStateException("Got " + nextV + " for next V" + f + '[' + + x.getFeatureIndex() + "]\n" + "Xi=" + Xi + ", Vif=" + theta.getV(f) + ", h=" + + h + ", gradV=" + gradV + ", dloss=" + dloss + ", sumViX=" + sumViX + ", n=" + + n + ", z=" + z); + } + if (MathUtils.closeToZero(nextV, 1E-9f)) { + theta.setV(f, 0.f); + if (theta.removable()) { // Whether other factors are zero filled or not? 
Remove if zero filled + removeEntry(theta); + } + return; } theta.setV(f, nextV); } - protected final float etaV(@Nonnull final Entry theta, final long t, final float grad) { + protected final float eta(@Nonnull final Entry theta, final long t, final float grad) { + return eta(theta, 0, t, grad); + } + + protected final float eta(@Nonnull final Entry theta, @Nonnegative final int f, final long t, + final float grad) { if (_useAdaGrad) { - double gg = theta.getSumOfSquaredGradientsV(); - theta.addGradientV(grad); - return (float) (_eta0_V / Math.sqrt(_eps + gg)); + double gg = theta.getSumOfSquaredGradients(f); + theta.addGradient(f, grad); + return (float) (_eta0 / Math.sqrt(_eps + gg)); } else { return _eta.eta(t); } @@ -187,10 +324,10 @@ private double sumVfX(@Nonnull final Feature[] x, final int i, @Nonnull final in } @Nonnull - protected abstract Entry getEntry(@Nonnull Feature x); + protected abstract Entry getEntryW(@Nonnull Feature x); - @Nonnull - protected abstract Entry getEntry(@Nonnull Feature x, @Nonnull int yField); + @Nullable + protected abstract Entry getEntryV(@Nonnull Feature x, @Nonnull int yField); @Override protected final String varDump(@Nonnull final Feature[] x) { diff --git a/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineUDTF.java b/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineUDTF.java index 67dbf87cf..56d9dc208 100644 --- a/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineUDTF.java +++ b/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineUDTF.java @@ -18,17 +18,18 @@ */ package hivemall.fm; +import hivemall.fm.FFMStringFeatureMapModel.EntryIterator; import hivemall.fm.FMHyperParameters.FFMHyperParameters; import hivemall.utils.collections.arrays.DoubleArray3D; import hivemall.utils.collections.lists.IntArrayList; import hivemall.utils.hadoop.HadoopUtils; -import hivemall.utils.hadoop.Text3; -import hivemall.utils.lang.NumberUtils; +import hivemall.utils.hadoop.HiveUtils; import hivemall.utils.math.MathUtils; -import java.io.IOException; import java.nio.ByteBuffer; import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; import javax.annotation.Nonnull; import javax.annotation.Nullable; @@ -44,6 +45,8 @@ import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; +import org.apache.hadoop.io.FloatWritable; +import org.apache.hadoop.io.IntWritable; import org.apache.hadoop.io.Text; /** @@ -60,8 +63,6 @@ public final class FieldAwareFactorizationMachineUDTF extends FactorizationMachi // ---------------------------------------- // Learning hyper-parameters/options - private boolean _FTRL; - private boolean _globalBias; private boolean _linearCoeff; @@ -87,26 +88,25 @@ protected Options getOptions() { opts.addOption("disable_wi", "no_coeff", false, "Not to include linear term [default: OFF]"); // feature hashing opts.addOption("feature_hashing", true, - "The number of bits for feature hashing in range [18,31] [default:21]"); - opts.addOption("num_fields", true, "The number of fields [default:1024]"); + "The number of bits for feature hashing in range [18,31] [default: -1]. 
No feature hashing for -1.");
+        opts.addOption("num_fields", true, "The number of fields [default: 256]");
+        // optimizer
+        opts.addOption("opt", "optimizer", true,
+            "Gradient descent optimizer: ftrl, adagrad, or sgd [default: ftrl]");
         // adagrad
-        opts.addOption("disable_adagrad", false,
-            "Whether to use AdaGrad for tuning learning rate [default: ON]");
-        opts.addOption("eta0_V", true, "The initial learning rate for V [default 1.0]");
-        opts.addOption("eps", true, "A constant used in the denominator of AdaGrad [default 1.0]");
+        opts.addOption("eps", true, "A constant used in the denominator of AdaGrad [default: 1.0]");
         // FTRL
-        opts.addOption("disable_ftrl", false,
-            "Whether not to use Follow-The-Regularized-Reader [default: OFF]");
         opts.addOption("alpha", "alphaFTRL", true,
-            "Alpha value (learning rate) of Follow-The-Regularized-Reader [default 0.1]");
+            "Alpha value (learning rate) of Follow-The-Regularized-Leader [default: 0.2]");
         opts.addOption("beta", "betaFTRL", true,
-            "Beta value (a learning smoothing parameter) of Follow-The-Regularized-Reader [default 1.0]");
+            "Beta value (a learning smoothing parameter) of Follow-The-Regularized-Leader [default: 1.0]");
         opts.addOption(
+            "l1",
             "lambda1", true,
-            "L1 regularization value of Follow-The-Regularized-Reader that controls model Sparseness [default 0.1]");
+            "L1 regularization value of Follow-The-Regularized-Leader that controls model sparseness [default: 0.001]");
-        opts.addOption("lambda2", true,
-            "L2 regularization value of Follow-The-Regularized-Reader [default 0.01]");
+        opts.addOption("l2", "lambda2", true,
+            "L2 regularization value of Follow-The-Regularized-Leader [default: 0.0001]");
         return opts;
     }
@@ -125,7 +125,6 @@ protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumen
         CommandLine cl = super.processOptions(argOIs);

         FFMHyperParameters params = (FFMHyperParameters) _params;
-        this._FTRL = params.useFTRL;
         this._globalBias = params.globalBias;
         this._linearCoeff = params.linearCoeff;
         this._numFeatures = params.numFeatures;
@@ -150,8 +149,14 @@ protected StructObjectInspector getOutputOI(@Nonnull FMHyperParameters params) {
         fieldNames.add("model_id");
         fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
-        fieldNames.add("model");
-        fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
+        fieldNames.add("i");
+        fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
+
+        fieldNames.add("Wi");
+        fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
+
+        fieldNames.add("Vi");
+        fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableFloatObjectInspector));

         return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
     }
@@ -184,20 +189,19 @@ public void train(@Nonnull final Feature[] x, final double y,

     @Override
     protected void trainTheta(@Nonnull final Feature[] x, final double y) throws HiveException {
-        final float eta_t = _etaEstimator.eta(_t);
-
         final double p = _ffmModel.predict(x);
         final double lossGrad = _ffmModel.dloss(p, y);

         double loss = _lossFunction.loss(p, y);
         _cvState.incrLoss(loss);

-        if (MathUtils.closeToZero(lossGrad)) {
+        if (MathUtils.closeToZero(lossGrad, 1E-9d)) {
             return;
         }

         // w0 update
         if (_globalBias) {
+            float eta_t = _etaEstimator.eta(_t);
             _ffmModel.updateW0(lossGrad, eta_t);
         }

@@ -210,14 +214,16 @@ protected void trainTheta(@Nonnull final Feature[] x, final double y) throws Hiv
             if (x_i.value == 0.f) {
                 continue;
             }
-            boolean 
useV = updateWi(lossGrad, x_i, eta_t); // wi update - if (useV == false) { - continue; + if (_linearCoeff) { + _ffmModel.updateWi(lossGrad, x_i, _t);// wi update } for (int fieldIndex = 0, size = fieldList.size(); fieldIndex < size; fieldIndex++) { final int yField = fieldList.get(fieldIndex); for (int f = 0, k = _factors; f < k; f++) { - double sumViX = sumVfX.get(i, fieldIndex, f); + final double sumViX = sumVfX.get(i, fieldIndex, f); + if (MathUtils.closeToZero(sumViX)) {// grad will be 0 => skip it + continue; + } _ffmModel.updateV(lossGrad, x_i, yField, f, sumViX, _t); } } @@ -229,18 +235,6 @@ protected void trainTheta(@Nonnull final Feature[] x, final double y) throws Hiv fieldList.clear(); } - private boolean updateWi(double lossGrad, @Nonnull Feature xi, float eta) { - if (!_linearCoeff) { - return true; - } - if (_FTRL) { - return _ffmModel.updateWiFTRL(lossGrad, xi, eta); - } else { - _ffmModel.updateWi(lossGrad, xi, eta); - return true; - } - } - @Nonnull private IntArrayList getFieldList(@Nonnull final Feature[] x) { for (Feature e : x) { @@ -257,7 +251,16 @@ protected IntFeature instantiateFeature(@Nonnull final ByteBuffer input) { @Override public void close() throws HiveException { + if (LOG.isInfoEnabled()) { + LOG.info(_ffmModel.getStatistics()); + } + + _ffmModel.disableInitV(); // trick to avoid re-instantiating removed (zero-filled) entry of V super.close(); + + if (LOG.isInfoEnabled()) { + LOG.info(_ffmModel.getStatistics()); + } this._ffmModel = null; } @@ -267,39 +270,54 @@ protected void forwardModel() throws HiveException { this._fieldList = null; this._sumVfX = null; - Text modelId = new Text(); - String taskId = HadoopUtils.getUniqueTaskIdString(); - modelId.set(taskId); - - FFMPredictionModel predModel = _ffmModel.toPredictionModel(); - this._ffmModel = null; // help GC - - if (LOG.isInfoEnabled()) { - LOG.info("Serializing a model '" + modelId + "'... 
Configured # features: " + _numFeatures + ", Configured # fields: " + _numFields
-                    + ", Actual # features: " + predModel.getActualNumFeatures()
-                    + ", Estimated uncompressed bytes: "
-                    + NumberUtils.prettySize(predModel.approxBytesConsumed()));
-        }
+        final int factors = _factors;
+        final IntWritable idx = new IntWritable();
+        final FloatWritable Wi = new FloatWritable(0.f);
+        final FloatWritable[] Vi = HiveUtils.newFloatArray(factors, 0.f);
+        final List<FloatWritable> ViObj = Arrays.asList(Vi);
+
+        final Object[] forwardObjs = new Object[4];
+        String modelId = HadoopUtils.getUniqueTaskIdString();
+        forwardObjs[0] = new Text(modelId);
+        forwardObjs[1] = idx;
+        forwardObjs[2] = Wi;
+        forwardObjs[3] = null; // Vi
+
+        // W0
+        idx.set(0);
+        Wi.set(_ffmModel.getW0());
+        forward(forwardObjs);

-        byte[] serialized;
-        try {
-            serialized = predModel.serialize();
-            predModel = null;
-        } catch (IOException e) {
-            throw new HiveException("Failed to serialize a model", e);
-        }
+        final EntryIterator itor = _ffmModel.entries();
+        final Entry entryW = itor.getEntryProbeW();
+        final Entry entryV = itor.getEntryProbeV();
+        final float[] Vf = new float[factors];
+        while (itor.next()) {
+            // set i
+            int i = itor.getEntryIndex();
+            idx.set(i);
+
+            if (Entry.isEntryW(i)) {// set Wi
+                itor.getEntry(entryW);
+                float w = entryW.getW();
+                if (w == 0.f) {
+                    continue; // skip w_i=0
+                }
+                Wi.set(w);
+                forwardObjs[2] = Wi;
+                forwardObjs[3] = null;
+            } else {// set Vif
+                itor.getEntry(entryV);
+                entryV.getV(Vf);
+                for (int f = 0; f < factors; f++) {
+                    Vi[f].set(Vf[f]);
+                }
+                forwardObjs[2] = null;
+                forwardObjs[3] = ViObj;
+            }

-        if (LOG.isInfoEnabled()) {
-            LOG.info("Forwarding a serialized/compressed model '" + modelId + "' of size: "
-                    + NumberUtils.prettySize(serialized.length));
+            forward(forwardObjs);
         }
-
-        Text modelObj = new Text3(serialized);
-        serialized = null;
-        Object[] forwardObjs = new Object[] {modelId, modelObj};
-
-        forward(forwardObjs);
     }
 }
diff --git a/core/src/main/java/hivemall/fm/IntFeature.java b/core/src/main/java/hivemall/fm/IntFeature.java
index 2052f7e07..64a4daafd 100644
--- a/core/src/main/java/hivemall/fm/IntFeature.java
+++ b/core/src/main/java/hivemall/fm/IntFeature.java
@@ -20,19 +20,21 @@

 import java.nio.ByteBuffer;

+import javax.annotation.Nonnegative;
 import javax.annotation.Nonnull;

 public final class IntFeature extends Feature {

+    @Nonnegative
     private int index;
     /** -1 if not defined */
     private short field;

-    public IntFeature(int index, double value) {
+    public IntFeature(@Nonnegative int index, double value) {
         this(index, (short) -1, value);
     }

-    public IntFeature(int index, short field, double value) {
+    public IntFeature(@Nonnegative int index, short field, double value) {
         super(value);
         this.field = field;
         this.index = index;
diff --git a/core/src/main/java/hivemall/ftvec/pairing/FeaturePairsUDTF.java b/core/src/main/java/hivemall/ftvec/pairing/FeaturePairsUDTF.java
index 6aebd64ab..3ec6ad7b4 100644
--- a/core/src/main/java/hivemall/ftvec/pairing/FeaturePairsUDTF.java
+++ b/core/src/main/java/hivemall/ftvec/pairing/FeaturePairsUDTF.java
@@ -19,15 +19,18 @@
 package hivemall.ftvec.pairing;

 import hivemall.UDTFWithOptions;
+import hivemall.fm.Feature;
 import hivemall.model.FeatureValue;
 import hivemall.utils.hadoop.HiveUtils;
 import hivemall.utils.hashing.HashFunction;
 import hivemall.utils.lang.Preconditions;
+import hivemall.utils.lang.Primitives;

 import java.util.ArrayList;
 import java.util.List;

 import javax.annotation.Nonnull;
+import javax.annotation.Nullable;

 import 
org.apache.commons.cli.CommandLine; import org.apache.commons.cli.Options; @@ -50,6 +53,8 @@ public final class FeaturePairsUDTF extends UDTFWithOptions { private Type _type; private RowProcessor _proc; + private int _numFields; + private int _numFeatures; public FeaturePairsUDTF() {} @@ -57,9 +62,14 @@ public FeaturePairsUDTF() {} protected Options getOptions() { Options opts = new Options(); opts.addOption("kpa", false, - "Generate feature pairs for Kernel-Expansion Passive Aggressive [default:true]"); + "Generate feature pairs for Kernel-Expansion Passive Aggressive [default:false]"); opts.addOption("ffm", false, "Generate feature pairs for Field-aware Factorization Machines [default:false]"); + // feature hashing + opts.addOption("p", "num_features", true, "The size of feature dimensions [default: -1]"); + opts.addOption("feature_hashing", true, + "The number of bits for feature hashing in range [18,31]. [default: -1] No feature hashing for -1."); + opts.addOption("num_fields", true, "The number of fields [default:1024]"); return opts; } @@ -70,13 +80,30 @@ protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumen String args = HiveUtils.getConstString(argOIs[1]); cl = parseOptions(args); - Preconditions.checkArgument(cl.getOptions().length == 1, UDFArgumentException.class, - "Only one option can be specified: " + cl.getArgList()); + Preconditions.checkArgument(cl.getOptions().length <= 3, UDFArgumentException.class, + "Too many options were specified: " + cl.getArgList()); if (cl.hasOption("kpa")) { this._type = Type.kpa; } else if (cl.hasOption("ffm")) { this._type = Type.ffm; + this._numFeatures = Primitives.parseInt(cl.getOptionValue("num_features"), -1); + if (_numFeatures == -1) { + int featureBits = Primitives.parseInt(cl.getOptionValue("feature_hashing"), -1); + if (featureBits != -1) { + if (featureBits < 18 || featureBits > 31) { + throw new UDFArgumentException( + "-feature_hashing MUST be in range [18,31]: " + featureBits); + } + this._numFeatures = 1 << featureBits; + } + } + this._numFields = Primitives.parseInt(cl.getOptionValue("num_fields"), + Feature.DEFAULT_NUM_FIELDS); + if (_numFields <= 1) { + throw new UDFArgumentException("-num_fields MUST be greater than 1: " + + _numFields); + } } else { throw new UDFArgumentException("Unsupported option: " + cl.getArgList().get(0)); } @@ -113,8 +140,16 @@ public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgu break; } case ffm: { - throw new UDFArgumentException("-ffm is not supported yet"); - //break; + this._proc = new FFMProcessor(fvOI); + fieldNames.add("i"); // index + fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector); + fieldNames.add("j"); // index + fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector); + fieldNames.add("xi"); + fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector); + fieldNames.add("xj"); + fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector); + break; } default: throw new UDFArgumentException("Illegal condition: " + _type); @@ -144,26 +179,7 @@ abstract class RowProcessor { this.fvOI = fvOI; } - void process(@Nonnull Object arg) throws HiveException { - final int size = fvOI.getListLength(arg); - if (size == 0) { - return; - } - - final List features = new ArrayList(size); - for (int i = 0; i < size; i++) { - Object f = fvOI.getListElement(arg, i); - if (f == null) { - continue; - } - FeatureValue fv = FeatureValue.parse(f, true); - features.add(fv); - } - - 
process(features); - } - - abstract void process(@Nonnull List features) throws HiveException; + abstract void process(@Nonnull Object arg) throws HiveException; } @@ -186,7 +202,22 @@ final class KPAProcessor extends RowProcessor { } @Override - void process(@Nonnull List features) throws HiveException { + void process(@Nonnull Object arg) throws HiveException { + final int size = fvOI.getListLength(arg); + if (size == 0) { + return; + } + + final List features = new ArrayList(size); + for (int i = 0; i < size; i++) { + Object f = fvOI.getListElement(arg, i); + if (f == null) { + continue; + } + FeatureValue fv = FeatureValue.parse(f, true); + features.add(fv); + } + forward[0] = f0; f0.set(0); forward[1] = null; @@ -222,6 +253,78 @@ void process(@Nonnull List features) throws HiveException { } } + final class FFMProcessor extends RowProcessor { + + @Nonnull + private final IntWritable f0, f1; + @Nonnull + private final DoubleWritable f2, f3; + @Nonnull + private final Writable[] forward; + + @Nullable + private transient Feature[] _features; + + FFMProcessor(@Nonnull ListObjectInspector fvOI) { + super(fvOI); + this.f0 = new IntWritable(); + this.f1 = new IntWritable(); + this.f2 = new DoubleWritable(); + this.f3 = new DoubleWritable(); + this.forward = new Writable[] {f0, null, null, null}; + this._features = null; + } + + @Override + void process(@Nonnull Object arg) throws HiveException { + final int size = fvOI.getListLength(arg); + if (size == 0) { + return; + } + + this._features = Feature.parseFFMFeatures(arg, fvOI, _features, _numFeatures, + _numFields); + + // W0 + f0.set(0); + forward[1] = null; + forward[2] = null; + forward[3] = null; + forward(forward); + + forward[2] = f2; + final Feature[] features = _features; + for (int i = 0, len = features.length; i < len; i++) { + Feature ei = features[i]; + + // Wi + f0.set(Feature.toIntFeature(ei)); + forward[1] = null; + f2.set(ei.getValue()); + forward[3] = null; + forward(forward); + + forward[1] = f1; + forward[3] = f3; + final int iField = ei.getField(); + for (int j = i + 1; j < len; j++) { + Feature ej = features[j]; + double xj = ej.getValue(); + int jField = ej.getField(); + + int ifj = Feature.toIntFeature(ei, jField, _numFields); + int jfi = Feature.toIntFeature(ej, iField, _numFields); + + // Vifj, Vjfi + f0.set(ifj); + f1.set(jfi); + // `f2` is consistently set to `xi` + f3.set(xj); + forward(forward); + } + } + } + } @Override public void close() throws HiveException { diff --git a/core/src/main/java/hivemall/ftvec/ranking/PositiveOnlyFeedback.java b/core/src/main/java/hivemall/ftvec/ranking/PositiveOnlyFeedback.java index 5e9f7971f..cdba00b0a 100644 --- a/core/src/main/java/hivemall/ftvec/ranking/PositiveOnlyFeedback.java +++ b/core/src/main/java/hivemall/ftvec/ranking/PositiveOnlyFeedback.java @@ -19,8 +19,8 @@ package hivemall.ftvec.ranking; import hivemall.utils.collections.lists.IntArrayList; -import hivemall.utils.collections.maps.IntOpenHashMap; -import hivemall.utils.collections.maps.IntOpenHashMap.IMapIterator; +import hivemall.utils.collections.maps.IntOpenHashTable; +import hivemall.utils.collections.maps.IntOpenHashTable.IMapIterator; import java.util.BitSet; @@ -30,13 +30,13 @@ public class PositiveOnlyFeedback { @Nonnull - protected final IntOpenHashMap rows; + protected final IntOpenHashTable rows; protected int maxItemId; protected int totalFeedbacks; public PositiveOnlyFeedback(int maxItemId) { - this.rows = new IntOpenHashMap(1024); + this.rows = new IntOpenHashTable(1024); this.maxItemId = maxItemId; 
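        // rows appears to be keyed by user index and to hold each user's positive
        // feedback; IntOpenHashTable is the open-addressing table that replaces the
        // removed IntOpenHashMap throughout this patch.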
this.totalFeedbacks = 0;
     }
diff --git a/core/src/main/java/hivemall/ftvec/trans/AddFieldIndicesUDF.java b/core/src/main/java/hivemall/ftvec/trans/AddFieldIndicesUDF.java
new file mode 100644
index 000000000..53b998cb5
--- /dev/null
+++ b/core/src/main/java/hivemall/ftvec/trans/AddFieldIndicesUDF.java
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.ftvec.trans;
+
+import hivemall.utils.hadoop.HiveUtils;
+import hivemall.utils.lang.Preconditions;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+import javax.annotation.Nonnull;
+
+import org.apache.hadoop.hive.ql.exec.Description;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.udf.UDFType;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
+import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+
+@Description(name = "add_field_indicies", value = "_FUNC_(array<string> features) "
+        + "- Returns an array of strings where field indices (:)* are prepended")
+@UDFType(deterministic = true, stateful = false)
+public final class AddFieldIndicesUDF extends GenericUDF {
+
+    private ListObjectInspector listOI;
+
+    @Override
+    public ObjectInspector initialize(@Nonnull ObjectInspector[] argOIs)
+            throws UDFArgumentException {
+        if (argOIs.length != 1) {
+            throw new UDFArgumentException("Expected a single argument: " + argOIs.length);
+        }
+
+        this.listOI = HiveUtils.asListOI(argOIs[0]);
+        if (!HiveUtils.isStringOI(listOI.getListElementObjectInspector())) {
+            throw new UDFArgumentException("Expected array<string> but got " + argOIs[0]);
+        }
+
+        return ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaStringObjectInspector);
+    }
+
+    @Override
+    public List<String> evaluate(@Nonnull DeferredObject[] args) throws HiveException {
+        Preconditions.checkArgument(args.length == 1);
+
+        final String[] features = HiveUtils.asStringArray(args[0], listOI);
+        if (features == null) {
+            return null;
+        }
+
+        final List<String> augmented = new ArrayList<>(features.length);
+        for (int i = 0; i < features.length; i++) {
+            final String f = features[i];
+            if (f == null) {
+                continue;
+            }
+            augmented.add((i + 1) + ":" + f);
+        }
+
+        return augmented;
+    }
+
+    @Override
+    public String getDisplayString(String[] args) {
+        return "add_field_indicies( " + Arrays.toString(args) + " )";
+    }
+
+}
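For reference, a minimal smoke test of the new UDF could look like the sketch below. It is illustrative only: the driver class name AddFieldIndicesExample and the expected-output comment are assumptions of this sketch, not part of the patch.

import java.util.Arrays;
import java.util.List;

import org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredJavaObject;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredObject;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;

import hivemall.ftvec.trans.AddFieldIndicesUDF;

public class AddFieldIndicesExample {

    public static void main(String[] args) throws Exception {
        AddFieldIndicesUDF udf = new AddFieldIndicesUDF();
        // a single array<string> argument, as enforced by initialize() above
        udf.initialize(new ObjectInspector[] {
                ObjectInspectorFactory.getStandardListObjectInspector(
                    PrimitiveObjectInspectorFactory.javaStringObjectInspector)});
        // each element gets its 1-origin position prepended: ["1:apple", "2:orange"]
        List<String> actual = udf.evaluate(new DeferredObject[] {
                new DeferredJavaObject(Arrays.asList("apple", "orange"))});
        System.out.println(actual);
        udf.close();
    }
}

diff --git a/core/src/main/java/hivemall/ftvec/trans/CategoricalFeaturesUDF.java 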
b/core/src/main/java/hivemall/ftvec/trans/CategoricalFeaturesUDF.java
index 98617bd21..4722efd07 100644
--- a/core/src/main/java/hivemall/ftvec/trans/CategoricalFeaturesUDF.java
+++ b/core/src/main/java/hivemall/ftvec/trans/CategoricalFeaturesUDF.java
@@ -18,6 +18,7 @@
  */
 package hivemall.ftvec.trans;

+import hivemall.UDFWithOptions;
 import hivemall.utils.hadoop.HiveUtils;

 import java.util.ArrayList;
@@ -26,26 +27,55 @@

 import javax.annotation.Nonnull;

+import org.apache.commons.cli.CommandLine;
+import org.apache.commons.cli.Options;
 import org.apache.hadoop.hive.ql.exec.Description;
 import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
 import org.apache.hadoop.hive.ql.metadata.HiveException;
 import org.apache.hadoop.hive.ql.udf.UDFType;
-import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
 import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
 import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
 import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
-import org.apache.hadoop.io.Text;

-@Description(name = "categorical_features",
-    value = "_FUNC_(array featureNames, ...) - Returns a feature vector array")
+@Description(
+        name = "categorical_features",
+        value = "_FUNC_(array<string> featureNames, feature1, feature2, .. [, const string options])"
+                + " - Returns a feature vector array")
 @UDFType(deterministic = true, stateful = false)
-public final class CategoricalFeaturesUDF extends GenericUDF {
+public final class CategoricalFeaturesUDF extends UDFWithOptions {

-    private String[] featureNames;
-    private PrimitiveObjectInspector[] inputOIs;
-    private List result;
+    private String[] _featureNames;
+    private PrimitiveObjectInspector[] _inputOIs;
+    private List<String> _result;
+
+    private boolean _emitNull = false;
+    private boolean _forceValue = false;
+
+    @Override
+    protected Options getOptions() {
+        Options opts = new Options();
+        opts.addOption("no_elim", "no_elimination", false,
+            "Whether to emit both NULL and value [default: false]");
+        opts.addOption("emit_null", false, "Whether to emit NULL [default: false]");
+        opts.addOption("force_value", false, "Whether to force emitting the value [default: false]");
+        return opts;
+    }
+
+    @Override
+    protected CommandLine processOptions(@Nonnull String optionValue) throws UDFArgumentException {
+        CommandLine cl = parseOptions(optionValue);
+        if (cl.hasOption("no_elim")) {
+            this._emitNull = true;
+            this._forceValue = true;
+        } else {
+            this._emitNull = cl.hasOption("emit_null");
+            this._forceValue = cl.hasOption("force_value");
+        }
+        return cl;
+    }

     @Override
     public ObjectInspector initialize(@Nonnull final ObjectInspector[] argOIs)
@@ -55,54 +85,91 @@ public ObjectInspector initialize(@Nonnull final ObjectInspector[] argOIs)
             throw new UDFArgumentException("argOIs.length must be greater that or equals to 2: "
                     + numArgOIs);
         }
-        this.featureNames = HiveUtils.getConstStringArray(argOIs[0]);
-        if (featureNames == null) {
+
+        this._featureNames = HiveUtils.getConstStringArray(argOIs[0]);
+        if (_featureNames == null) {
             throw new UDFArgumentException("#featureNames should not be null");
         }
-        int numFeatureNames = featureNames.length;
+        int numFeatureNames = _featureNames.length;
         if (numFeatureNames < 1) {
             throw new UDFArgumentException("#featureNames 
must be greater than or equals to 1: " + numFeatureNames); } - int numFeatures = numArgOIs - 1; + for (String featureName : _featureNames) { + if (featureName == null) { + throw new UDFArgumentException("featureName should not be null: " + + Arrays.toString(_featureNames)); + } else if (featureName.indexOf(':') != -1) { + throw new UDFArgumentException("featureName should not include colon: " + + featureName); + } + } + + final int numFeatures; + final int lastArgIndex = numArgOIs - 1; + if (lastArgIndex > numFeatureNames) { + if (lastArgIndex == (numFeatureNames + 1) + && HiveUtils.isConstString(argOIs[lastArgIndex])) { + String optionValue = HiveUtils.getConstString(argOIs[lastArgIndex]); + processOptions(optionValue); + numFeatures = numArgOIs - 2; + } else { + throw new UDFArgumentException( + "Unexpected arguments for _FUNC_" + + "(const array featureNames, feature1, feature2, .. [, const string options])"); + } + } else { + numFeatures = lastArgIndex; + } if (numFeatureNames != numFeatures) { - throw new UDFArgumentException("#featureNames '" + numFeatureNames - + "' != #arguments '" + numFeatures + "'"); + throw new UDFArgumentLengthException("#featureNames '" + numFeatureNames + + "' != #features '" + numFeatures + "'"); } - this.inputOIs = new PrimitiveObjectInspector[numFeatures]; + this._inputOIs = new PrimitiveObjectInspector[numFeatures]; for (int i = 0; i < numFeatures; i++) { ObjectInspector oi = argOIs[i + 1]; - inputOIs[i] = HiveUtils.asPrimitiveObjectInspector(oi); + _inputOIs[i] = HiveUtils.asPrimitiveObjectInspector(oi); } - this.result = new ArrayList(numFeatures); + this._result = new ArrayList(numFeatures); - return ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableStringObjectInspector); + return ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaStringObjectInspector); } @Override - public List evaluate(@Nonnull final DeferredObject[] arguments) throws HiveException { - result.clear(); + public List evaluate(@Nonnull final DeferredObject[] arguments) throws HiveException { + _result.clear(); - final int size = arguments.length - 1; + final int size = _featureNames.length; for (int i = 0; i < size; i++) { Object argument = arguments[i + 1].get(); if (argument == null) { + if (_emitNull) { + _result.add(null); + } continue; } - PrimitiveObjectInspector oi = inputOIs[i]; + PrimitiveObjectInspector oi = _inputOIs[i]; String s = PrimitiveObjectInspectorUtils.getString(argument, oi); if (s.isEmpty()) { + if (_emitNull) { + _result.add(null); + } continue; } - // categorical feature representation - String featureName = featureNames[i]; - Text f = new Text(featureName + '#' + s); - result.add(f); + // categorical feature representation + final String f; + if (_forceValue) { + f = _featureNames[i] + '#' + s + ":1"; + } else { + f = _featureNames[i] + '#' + s; + } + _result.add(f); + } - return result; + return _result; } @Override diff --git a/core/src/main/java/hivemall/ftvec/trans/FFMFeaturesUDF.java b/core/src/main/java/hivemall/ftvec/trans/FFMFeaturesUDF.java index c98ffdaad..eead73849 100644 --- a/core/src/main/java/hivemall/ftvec/trans/FFMFeaturesUDF.java +++ b/core/src/main/java/hivemall/ftvec/trans/FFMFeaturesUDF.java @@ -23,6 +23,7 @@ import hivemall.utils.hadoop.HiveUtils; import hivemall.utils.hashing.MurmurHash3; import hivemall.utils.lang.Primitives; +import hivemall.utils.lang.StringUtils; import java.util.ArrayList; import java.util.Arrays; @@ -59,6 +60,7 @@ public final class 
FFMFeaturesUDF extends UDFWithOptions { private boolean _mhash = true; private int _numFeatures = Feature.DEFAULT_NUM_FEATURES; private int _numFields = Feature.DEFAULT_NUM_FIELDS; + private boolean _emitIndicies = false; @Override protected Options getOptions() { @@ -66,9 +68,11 @@ protected Options getOptions() { opts.addOption("no_hash", "disable_feature_hashing", false, "Wheather to disable feature hashing [default: false]"); // feature hashing + opts.addOption("p", "num_features", true, "The size of feature dimensions [default: -1]"); opts.addOption("hash", "feature_hashing", true, "The number of bits for feature hashing in range [18,31] [default:21]"); opts.addOption("fields", "num_fields", true, "The number of fields [default:1024]"); + opts.addOption("emit_indicies", false, "Emit indicies for fields [default: false]"); return opts; } @@ -77,19 +81,27 @@ protected CommandLine processOptions(@Nonnull String optionValue) throws UDFArgu CommandLine cl = parseOptions(optionValue); // feature hashing - int hashbits = Primitives.parseInt(cl.getOptionValue("feature_hashing"), - Feature.DEFAULT_FEATURE_BITS); - if (hashbits < 18 || hashbits > 31) { - throw new UDFArgumentException("-feature_hashing MUST be in range [18,31]: " + hashbits); + int numFeatures = Primitives.parseInt(cl.getOptionValue("num_features"), -1); + if (numFeatures == -1) { + int hashbits = Primitives.parseInt(cl.getOptionValue("feature_hashing"), + Feature.DEFAULT_FEATURE_BITS); + if (hashbits < 18 || hashbits > 31) { + throw new UDFArgumentException("-feature_hashing MUST be in range [18,31]: " + + hashbits); + } + numFeatures = 1 << hashbits; } - int numFeatures = 1 << hashbits; + this._numFeatures = numFeatures; + int numFields = Primitives.parseInt(cl.getOptionValue("num_fields"), Feature.DEFAULT_NUM_FIELDS); if (numFields <= 1) { throw new UDFArgumentException("-num_fields MUST be greater than 1: " + numFields); } - this._numFeatures = numFeatures; this._numFields = numFields; + + this._emitIndicies = cl.hasOption("emit_indicies"); + return cl; } @@ -111,7 +123,10 @@ public ObjectInspector initialize(@Nonnull final ObjectInspector[] argOIs) + numFeatureNames); } for (String featureName : _featureNames) { - if (featureName.indexOf(':') != -1) { + if (featureName == null) { + throw new UDFArgumentException("featureName should not be null: " + + Arrays.toString(_featureNames)); + } else if (featureName.indexOf(':') != -1) { throw new UDFArgumentException("featureName should not include colon: " + featureName); } @@ -174,18 +189,20 @@ public List evaluate(@Nonnull final DeferredObject[] arguments) throws Hiv // categorical feature representation final String fv; if (_mhash) { - int field = MurmurHash3.murmurhash3(_featureNames[i], _numFields); + int field = _emitIndicies ? 
i : MurmurHash3.murmurhash3(_featureNames[i], + _numFields); // +NUM_FIELD to avoid conflict to quantitative features int index = MurmurHash3.murmurhash3(feature, _numFeatures) + _numFields; fv = builder.append(field).append(':').append(index).append(":1").toString(); - builder.setLength(0); + StringUtils.clear(builder); } else { - fv = builder.append(featureName) - .append(':') - .append(feature) - .append(":1") - .toString(); - builder.setLength(0); + if (_emitIndicies) { + builder.append(i); + } else { + builder.append(featureName); + } + fv = builder.append(':').append(feature).append(":1").toString(); + StringUtils.clear(builder); } _result.add(new Text(fv)); diff --git a/core/src/main/java/hivemall/ftvec/trans/QuantifiedFeaturesUDTF.java b/core/src/main/java/hivemall/ftvec/trans/QuantifiedFeaturesUDTF.java index 2886996ab..846be97d8 100644 --- a/core/src/main/java/hivemall/ftvec/trans/QuantifiedFeaturesUDTF.java +++ b/core/src/main/java/hivemall/ftvec/trans/QuantifiedFeaturesUDTF.java @@ -23,6 +23,7 @@ import java.util.ArrayList; import java.util.Arrays; +import java.util.List; import org.apache.hadoop.hive.ql.exec.Description; import org.apache.hadoop.hive.ql.exec.UDFArgumentException; @@ -39,7 +40,7 @@ @Description( name = "quantified_features", - value = "_FUNC_(boolean output, col1, col2, ...) - Returns an identified features in a dence array") + value = "_FUNC_(boolean output, col1, col2, ...) - Returns an identified features in a dense array") public final class QuantifiedFeaturesUDTF extends GenericUDTF { private BooleanObjectInspector boolOI; @@ -76,8 +77,8 @@ public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgu } } - ArrayList fieldNames = new ArrayList(outputSize); - ArrayList fieldOIs = new ArrayList(outputSize); + List fieldNames = new ArrayList(outputSize); + List fieldOIs = new ArrayList(outputSize); fieldNames.add("features"); fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector)); diff --git a/core/src/main/java/hivemall/ftvec/trans/QuantitativeFeaturesUDF.java b/core/src/main/java/hivemall/ftvec/trans/QuantitativeFeaturesUDF.java index 43f837fa4..38e35e2a4 100644 --- a/core/src/main/java/hivemall/ftvec/trans/QuantitativeFeaturesUDF.java +++ b/core/src/main/java/hivemall/ftvec/trans/QuantitativeFeaturesUDF.java @@ -18,6 +18,7 @@ */ package hivemall.ftvec.trans; +import hivemall.UDFWithOptions; import hivemall.utils.hadoop.HiveUtils; import java.util.ArrayList; @@ -26,11 +27,13 @@ import javax.annotation.Nonnull; +import org.apache.commons.cli.CommandLine; +import org.apache.commons.cli.Options; import org.apache.hadoop.hive.ql.exec.Description; import org.apache.hadoop.hive.ql.exec.UDFArgumentException; +import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException; import org.apache.hadoop.hive.ql.metadata.HiveException; import org.apache.hadoop.hive.ql.udf.UDFType; -import org.apache.hadoop.hive.ql.udf.generic.GenericUDF; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; @@ -39,14 +42,32 @@ import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils; import org.apache.hadoop.io.Text; -@Description(name = "quantitative_features", - value = "_FUNC_(array featureNames, ...) 
- Returns a feature vector array") +@Description( + name = "quantitative_features", + value = "_FUNC_(array featureNames, feature1, feature2, .. [, const string options])" + + " - Returns a feature vector array") @UDFType(deterministic = true, stateful = false) -public final class QuantitativeFeaturesUDF extends GenericUDF { +public final class QuantitativeFeaturesUDF extends UDFWithOptions { - private String[] featureNames; - private PrimitiveObjectInspector[] inputOIs; - private List result; + private String[] _featureNames; + private PrimitiveObjectInspector[] _inputOIs; + private List _result; + + private boolean _emitNull = false; + + @Override + protected Options getOptions() { + Options opts = new Options(); + opts.addOption("emit_null", false, "Wheather to emit NULL [default: false]"); + return opts; + } + + @Override + protected CommandLine processOptions(@Nonnull String optionValue) throws UDFArgumentException { + CommandLine cl = parseOptions(optionValue); + this._emitNull = cl.hasOption("emit_null"); + return cl; + } @Override public ObjectInspector initialize(@Nonnull final ObjectInspector[] argOIs) @@ -56,58 +77,92 @@ public ObjectInspector initialize(@Nonnull final ObjectInspector[] argOIs) throw new UDFArgumentException("argOIs.length must be greater that or equals to 2: " + numArgOIs); } - this.featureNames = HiveUtils.getConstStringArray(argOIs[0]); - if (featureNames == null) { + + this._featureNames = HiveUtils.getConstStringArray(argOIs[0]); + if (_featureNames == null) { throw new UDFArgumentException("#featureNames should not be null"); } - int numFeatureNames = featureNames.length; + int numFeatureNames = _featureNames.length; if (numFeatureNames < 1) { throw new UDFArgumentException("#featureNames must be greater than or equals to 1: " + numFeatureNames); } - int numFeatures = numArgOIs - 1; + for (String featureName : _featureNames) { + if (featureName == null) { + throw new UDFArgumentException("featureName should not be null: " + + Arrays.toString(_featureNames)); + } else if (featureName.indexOf(':') != -1) { + throw new UDFArgumentException("featureName should not include colon: " + + featureName); + } + } + + final int numFeatures; + final int lastArgIndex = numArgOIs - 1; + if (lastArgIndex > numFeatureNames) { + if (lastArgIndex == (numFeatureNames + 1) + && HiveUtils.isConstString(argOIs[lastArgIndex])) { + String optionValue = HiveUtils.getConstString(argOIs[lastArgIndex]); + processOptions(optionValue); + numFeatures = numArgOIs - 2; + } else { + throw new UDFArgumentException( + "Unexpected arguments for _FUNC_" + + "(const array featureNames, feature1, feature2, .. 
[, const string options])"); + } + } else { + numFeatures = lastArgIndex; + } if (numFeatureNames != numFeatures) { - throw new UDFArgumentException("#featureNames '" + numFeatureNames - + "' != #arguments '" + numFeatures + "'"); + throw new UDFArgumentLengthException("#featureNames '" + numFeatureNames + + "' != #features '" + numFeatures + "'"); } - this.inputOIs = new PrimitiveObjectInspector[numFeatures]; + this._inputOIs = new PrimitiveObjectInspector[numFeatures]; for (int i = 0; i < numFeatures; i++) { ObjectInspector oi = argOIs[i + 1]; - inputOIs[i] = HiveUtils.asDoubleCompatibleOI(oi); + _inputOIs[i] = HiveUtils.asDoubleCompatibleOI(oi); } - this.result = new ArrayList(numFeatures); + this._result = new ArrayList(numFeatures); return ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableStringObjectInspector); } @Override public List evaluate(@Nonnull final DeferredObject[] arguments) throws HiveException { - result.clear(); + _result.clear(); - final int size = arguments.length - 1; + final int size = _featureNames.length; for (int i = 0; i < size; i++) { Object argument = arguments[i + 1].get(); if (argument == null) { + if (_emitNull) { + _result.add(null); + } continue; } - PrimitiveObjectInspector oi = inputOIs[i]; + PrimitiveObjectInspector oi = _inputOIs[i]; if (oi.getPrimitiveCategory() == PrimitiveCategory.STRING) { String s = argument.toString(); if (s.isEmpty()) { + if (_emitNull) { + _result.add(null); + } continue; } } final double v = PrimitiveObjectInspectorUtils.getDouble(argument, oi); if (v != 0.d) { - String featureName = featureNames[i]; - Text f = new Text(featureName + ':' + v); - result.add(f); + Text f = new Text(_featureNames[i] + ':' + v); + _result.add(f); + } else if (_emitNull) { + Text f = new Text(_featureNames[i] + ":0"); + _result.add(f); } } - return result; + return _result; } @Override diff --git a/core/src/main/java/hivemall/ftvec/trans/VectorizeFeaturesUDF.java b/core/src/main/java/hivemall/ftvec/trans/VectorizeFeaturesUDF.java index 48bf12665..f2ecbb64b 100644 --- a/core/src/main/java/hivemall/ftvec/trans/VectorizeFeaturesUDF.java +++ b/core/src/main/java/hivemall/ftvec/trans/VectorizeFeaturesUDF.java @@ -18,6 +18,7 @@ */ package hivemall.ftvec.trans; +import hivemall.UDFWithOptions; import hivemall.utils.hadoop.HiveUtils; import hivemall.utils.lang.StringUtils; @@ -27,11 +28,13 @@ import javax.annotation.Nonnull; +import org.apache.commons.cli.CommandLine; +import org.apache.commons.cli.Options; import org.apache.hadoop.hive.ql.exec.Description; import org.apache.hadoop.hive.ql.exec.UDFArgumentException; +import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException; import org.apache.hadoop.hive.ql.metadata.HiveException; import org.apache.hadoop.hive.ql.udf.UDFType; -import org.apache.hadoop.hive.ql.udf.generic.GenericUDF; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; @@ -40,14 +43,32 @@ import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils; import org.apache.hadoop.io.Text; -@Description(name = "vectorize_features", - value = "_FUNC_(array featureNames, ...) - Returns a feature vector array") +@Description( + name = "vectorize_features", + value = "_FUNC_(array featureNames, feature1, feature2, .. 
[, const string options])" + + " - Returns a feature vector array") @UDFType(deterministic = true, stateful = false) -public final class VectorizeFeaturesUDF extends GenericUDF { +public final class VectorizeFeaturesUDF extends UDFWithOptions { - private String[] featureNames; - private PrimitiveObjectInspector[] inputOIs; - private List result; + private String[] _featureNames; + private PrimitiveObjectInspector[] _inputOIs; + private List _result; + + private boolean _emitNull = false; + + @Override + protected Options getOptions() { + Options opts = new Options(); + opts.addOption("emit_null", false, "Wheather to emit NULL [default: false]"); + return opts; + } + + @Override + protected CommandLine processOptions(@Nonnull String optionValue) throws UDFArgumentException { + CommandLine cl = parseOptions(optionValue); + this._emitNull = cl.hasOption("emit_null"); + return cl; + } @Override public ObjectInspector initialize(@Nonnull final ObjectInspector[] argOIs) @@ -57,63 +78,96 @@ public ObjectInspector initialize(@Nonnull final ObjectInspector[] argOIs) throw new UDFArgumentException("argOIs.length must be greater that or equals to 2: " + numArgOIs); } - this.featureNames = HiveUtils.getConstStringArray(argOIs[0]); - if (featureNames == null) { + + this._featureNames = HiveUtils.getConstStringArray(argOIs[0]); + if (_featureNames == null) { throw new UDFArgumentException("#featureNames should not be null"); } - int numFeatureNames = featureNames.length; + int numFeatureNames = _featureNames.length; if (numFeatureNames < 1) { throw new UDFArgumentException("#featureNames must be greater than or equals to 1: " + numFeatureNames); } - int numFeatures = numArgOIs - 1; + for (String featureName : _featureNames) { + if (featureName == null) { + throw new UDFArgumentException("featureName should not be null: " + + Arrays.toString(_featureNames)); + } else if (featureName.indexOf(':') != -1) { + throw new UDFArgumentException("featureName should not include colon: " + + featureName); + } + } + + final int numFeatures; + final int lastArgIndex = numArgOIs - 1; + if (lastArgIndex > numFeatureNames) { + if (lastArgIndex == (numFeatureNames + 1) + && HiveUtils.isConstString(argOIs[lastArgIndex])) { + String optionValue = HiveUtils.getConstString(argOIs[lastArgIndex]); + processOptions(optionValue); + numFeatures = numArgOIs - 2; + } else { + throw new UDFArgumentException( + "Unexpected arguments for _FUNC_" + + "(const array featureNames, feature1, feature2, .. 
[, const string options])"); + } + } else { + numFeatures = lastArgIndex; + } if (numFeatureNames != numFeatures) { - throw new UDFArgumentException("#featureNames '" + numFeatureNames - + "' != #arguments '" + numFeatures + "'"); + throw new UDFArgumentLengthException("#featureNames '" + numFeatureNames + + "' != #features '" + numFeatures + "'"); } - this.inputOIs = new PrimitiveObjectInspector[numFeatures]; + this._inputOIs = new PrimitiveObjectInspector[numFeatures]; for (int i = 0; i < numFeatures; i++) { ObjectInspector oi = argOIs[i + 1]; - inputOIs[i] = HiveUtils.asPrimitiveObjectInspector(oi); + _inputOIs[i] = HiveUtils.asPrimitiveObjectInspector(oi); } - this.result = new ArrayList(numFeatures); + this._result = new ArrayList(numFeatures); return ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableStringObjectInspector); } @Override public List evaluate(@Nonnull final DeferredObject[] arguments) throws HiveException { - result.clear(); + _result.clear(); - final int size = arguments.length - 1; + final int size = _featureNames.length; for (int i = 0; i < size; i++) { Object argument = arguments[i + 1].get(); if (argument == null) { + if (_emitNull) { + _result.add(null); + } continue; } - PrimitiveObjectInspector oi = inputOIs[i]; + PrimitiveObjectInspector oi = _inputOIs[i]; if (oi.getPrimitiveCategory() == PrimitiveCategory.STRING) { String s = PrimitiveObjectInspectorUtils.getString(argument, oi); if (s.isEmpty()) { + if (_emitNull) { + _result.add(null); + } continue; } - if (StringUtils.isNumber(s) == false) {// categorical feature representation - String featureName = featureNames[i]; - Text f = new Text(featureName + '#' + s); - result.add(f); + if (StringUtils.isNumber(s) == false) {// categorical feature representation + Text f = new Text(_featureNames[i] + '#' + s); + _result.add(f); continue; } } - float v = PrimitiveObjectInspectorUtils.getFloat(argument, oi); + final float v = PrimitiveObjectInspectorUtils.getFloat(argument, oi); if (v != 0.f) { - String featureName = featureNames[i]; - Text f = new Text(featureName + ':' + v); - result.add(f); + Text f = new Text(_featureNames[i] + ':' + v); + _result.add(f); + } else if (_emitNull) { + Text f = new Text(_featureNames[i] + ":0"); + _result.add(f); } } - return result; + return _result; } @Override diff --git a/core/src/main/java/hivemall/mf/FactorizedModel.java b/core/src/main/java/hivemall/mf/FactorizedModel.java index a4bea00e9..1b7140f47 100644 --- a/core/src/main/java/hivemall/mf/FactorizedModel.java +++ b/core/src/main/java/hivemall/mf/FactorizedModel.java @@ -18,7 +18,7 @@ */ package hivemall.mf; -import hivemall.utils.collections.maps.IntOpenHashMap; +import hivemall.utils.collections.maps.IntOpenHashTable; import hivemall.utils.math.MathUtils; import java.util.Random; @@ -42,10 +42,10 @@ public final class FactorizedModel { private int minIndex, maxIndex; @Nonnull private Rating meanRating; - private IntOpenHashMap users; - private IntOpenHashMap items; - private IntOpenHashMap userBias; - private IntOpenHashMap itemBias; + private IntOpenHashTable users; + private IntOpenHashTable items; + private IntOpenHashTable userBias; + private IntOpenHashTable itemBias; private final Random[] randU, randI; @@ -67,10 +67,10 @@ public FactorizedModel(@Nonnull RatingInitilizer ratingInitializer, @Nonnegative this.minIndex = 0; this.maxIndex = 0; this.meanRating = ratingInitializer.newRating(meanRating); - this.users = new IntOpenHashMap(expectedSize); - this.items = new 
IntOpenHashMap(expectedSize); - this.userBias = new IntOpenHashMap(expectedSize); - this.itemBias = new IntOpenHashMap(expectedSize); + this.users = new IntOpenHashTable(expectedSize); + this.items = new IntOpenHashTable(expectedSize); + this.userBias = new IntOpenHashTable(expectedSize); + this.itemBias = new IntOpenHashTable(expectedSize); this.randU = newRandoms(factor, 31L); this.randI = newRandoms(factor, 41L); } diff --git a/core/src/main/java/hivemall/model/AbstractPredictionModel.java b/core/src/main/java/hivemall/model/AbstractPredictionModel.java index 95935d34c..cd298a785 100644 --- a/core/src/main/java/hivemall/model/AbstractPredictionModel.java +++ b/core/src/main/java/hivemall/model/AbstractPredictionModel.java @@ -22,7 +22,7 @@ import hivemall.mix.MixedWeight; import hivemall.mix.MixedWeight.WeightWithCovar; import hivemall.mix.MixedWeight.WeightWithDelta; -import hivemall.utils.collections.maps.IntOpenHashMap; +import hivemall.utils.collections.maps.IntOpenHashTable; import hivemall.utils.collections.maps.OpenHashMap; import javax.annotation.Nonnull; @@ -37,7 +37,7 @@ public abstract class AbstractPredictionModel implements PredictionModel { private long numMixed; private boolean cancelMixRequest; - private IntOpenHashMap mixedRequests_i; + private IntOpenHashTable mixedRequests_i; private OpenHashMap mixedRequests_o; public AbstractPredictionModel() { @@ -58,7 +58,7 @@ public void configureMix(@Nonnull ModelUpdateHandler handler, boolean cancelMixR this.cancelMixRequest = cancelMixRequest; if (cancelMixRequest) { if (isDenseModel()) { - this.mixedRequests_i = new IntOpenHashMap(327680); + this.mixedRequests_i = new IntOpenHashTable(327680); } else { this.mixedRequests_o = new OpenHashMap(327680); } diff --git a/core/src/main/java/hivemall/model/NewSparseModel.java b/core/src/main/java/hivemall/model/NewSparseModel.java index 8326d22b7..5c0a6c7f0 100644 --- a/core/src/main/java/hivemall/model/NewSparseModel.java +++ b/core/src/main/java/hivemall/model/NewSparseModel.java @@ -194,7 +194,7 @@ public boolean contains(@Nonnull final Object feature) { @SuppressWarnings("unchecked") @Override public IMapIterator entries() { - return (IMapIterator) weights.entries(); + return (IMapIterator) weights.entries(true); } } diff --git a/core/src/main/java/hivemall/model/SparseModel.java b/core/src/main/java/hivemall/model/SparseModel.java index cb8ab9fb9..65e751d66 100644 --- a/core/src/main/java/hivemall/model/SparseModel.java +++ b/core/src/main/java/hivemall/model/SparseModel.java @@ -183,7 +183,7 @@ public boolean contains(@Nonnull final Object feature) { @SuppressWarnings("unchecked") @Override public IMapIterator entries() { - return (IMapIterator) weights.entries(); + return (IMapIterator) weights.entries(true); } } diff --git a/core/src/main/java/hivemall/tools/array/ArrayAvgGenericUDAF.java b/core/src/main/java/hivemall/tools/array/ArrayAvgGenericUDAF.java index a2e3e552f..6dbb7d569 100644 --- a/core/src/main/java/hivemall/tools/array/ArrayAvgGenericUDAF.java +++ b/core/src/main/java/hivemall/tools/array/ArrayAvgGenericUDAF.java @@ -18,6 +18,10 @@ */ package hivemall.tools.array; +import static org.apache.hadoop.hive.ql.util.JavaDataModel.JAVA64_ARRAY_META; +import static org.apache.hadoop.hive.ql.util.JavaDataModel.JAVA64_REF; +import static org.apache.hadoop.hive.ql.util.JavaDataModel.PRIMITIVES1; +import static org.apache.hadoop.hive.ql.util.JavaDataModel.PRIMITIVES2; import hivemall.utils.hadoop.HiveUtils; import hivemall.utils.hadoop.WritableUtils; @@ -34,6 +38,7 @@ import 
org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver; import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator; import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AbstractAggregationBuffer; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationType; import org.apache.hadoop.hive.serde2.lazybinary.LazyBinaryArray; import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; @@ -220,7 +225,8 @@ public List terminate(@SuppressWarnings("deprecation") Aggregatio } } - public static class ArrayAvgAggregationBuffer extends AbstractAggregationBuffer { + @AggregationType(estimable = true) + public static final class ArrayAvgAggregationBuffer extends AbstractAggregationBuffer { int _size; // note that primitive array cannot be serialized by JDK serializer @@ -289,6 +295,15 @@ void merge(final int o_size, @Nonnull final Object o_sum, @Nonnull final Object } } + @Override + public int estimate() { + if (_size == -1) { + return JAVA64_REF; + } else { + return PRIMITIVES1 + 2 * (JAVA64_ARRAY_META + PRIMITIVES2 * _size); + } + } + } } diff --git a/core/src/main/java/hivemall/utils/buffer/HeapBuffer.java b/core/src/main/java/hivemall/utils/buffer/HeapBuffer.java index e0a3c9ec3..10051a9c8 100644 --- a/core/src/main/java/hivemall/utils/buffer/HeapBuffer.java +++ b/core/src/main/java/hivemall/utils/buffer/HeapBuffer.java @@ -20,7 +20,6 @@ import hivemall.utils.lang.NumberUtils; import hivemall.utils.lang.Preconditions; -import hivemall.utils.lang.Primitives; import hivemall.utils.lang.SizeOf; import hivemall.utils.lang.UnsafeUtils; @@ -97,8 +96,8 @@ public long allocate(final int bytes) { Preconditions.checkArgument(bytes <= _chunkBytes, "Cannot allocate memory greater than %s bytes: %s", _chunkBytes, bytes); - int i = Primitives.castToInt(_position / _chunkBytes); - final int j = Primitives.castToInt(_position % _chunkBytes); + int i = NumberUtils.castToInt(_position / _chunkBytes); + final int j = NumberUtils.castToInt(_position % _chunkBytes); if (bytes > (_chunkBytes - j)) { // cannot allocate the object in the current chunk // so, skip the current chunk @@ -144,7 +143,7 @@ private void validatePointer(final long ptr) { public byte getByte(final long ptr) { validatePointer(ptr); - int i = Primitives.castToInt(ptr / _chunkBytes); + int i = NumberUtils.castToInt(ptr / _chunkBytes); int[] chunk = _chunks[i]; long j = offset(ptr); return _UNSAFE.getByte(chunk, j); @@ -152,7 +151,7 @@ public byte getByte(final long ptr) { public void putByte(final long ptr, final byte value) { validatePointer(ptr); - int i = Primitives.castToInt(ptr / _chunkBytes); + int i = NumberUtils.castToInt(ptr / _chunkBytes); int[] chunk = _chunks[i]; long j = offset(ptr); _UNSAFE.putByte(chunk, j, value); @@ -160,7 +159,7 @@ public void putByte(final long ptr, final byte value) { public int getInt(final long ptr) { validatePointer(ptr); - int i = Primitives.castToInt(ptr / _chunkBytes); + int i = NumberUtils.castToInt(ptr / _chunkBytes); int[] chunk = _chunks[i]; long j = offset(ptr); return _UNSAFE.getInt(chunk, j); @@ -168,7 +167,7 @@ public int getInt(final long ptr) { public void putInt(final long ptr, final int value) { validatePointer(ptr); - int i = Primitives.castToInt(ptr / _chunkBytes); + int i = NumberUtils.castToInt(ptr / _chunkBytes); int[] chunk = _chunks[i]; long j = offset(ptr); _UNSAFE.putInt(chunk, j, value); @@ -176,7 +175,7 @@ public void putInt(final long ptr, final int 
value) { public short getShort(final long ptr) { validatePointer(ptr); - int i = Primitives.castToInt(ptr / _chunkBytes); + int i = NumberUtils.castToInt(ptr / _chunkBytes); int[] chunk = _chunks[i]; long j = offset(ptr); return _UNSAFE.getShort(chunk, j); @@ -184,7 +183,7 @@ public short getShort(final long ptr) { public void putShort(final long ptr, final short value) { validatePointer(ptr); - int i = Primitives.castToInt(ptr / _chunkBytes); + int i = NumberUtils.castToInt(ptr / _chunkBytes); int[] chunk = _chunks[i]; long j = offset(ptr); _UNSAFE.putShort(chunk, j, value); @@ -192,7 +191,7 @@ public void putShort(final long ptr, final short value) { public char getChar(final long ptr) { validatePointer(ptr); - int i = Primitives.castToInt(ptr / _chunkBytes); + int i = NumberUtils.castToInt(ptr / _chunkBytes); int[] chunk = _chunks[i]; long j = offset(ptr); return _UNSAFE.getChar(chunk, j); @@ -200,14 +199,14 @@ public char getChar(final long ptr) { public void putChar(final long ptr, final char value) { validatePointer(ptr); - int i = Primitives.castToInt(ptr / _chunkBytes); + int i = NumberUtils.castToInt(ptr / _chunkBytes); int[] chunk = _chunks[i]; long j = offset(ptr); _UNSAFE.putChar(chunk, j, value); } public long getLong(final long ptr) { - int i = Primitives.castToInt(ptr / _chunkBytes); + int i = NumberUtils.castToInt(ptr / _chunkBytes); int[] chunk = _chunks[i]; long j = offset(ptr); return _UNSAFE.getLong(chunk, j); @@ -215,7 +214,7 @@ public long getLong(final long ptr) { public void putLong(final long ptr, final long value) { validatePointer(ptr); - int i = Primitives.castToInt(ptr / _chunkBytes); + int i = NumberUtils.castToInt(ptr / _chunkBytes); int[] chunk = _chunks[i]; long j = offset(ptr); _UNSAFE.putLong(chunk, j, value); @@ -223,7 +222,7 @@ public void putLong(final long ptr, final long value) { public float getFloat(final long ptr) { validatePointer(ptr); - int i = Primitives.castToInt(ptr / _chunkBytes); + int i = NumberUtils.castToInt(ptr / _chunkBytes); int[] chunk = _chunks[i]; long j = offset(ptr); return _UNSAFE.getFloat(chunk, j); @@ -231,7 +230,7 @@ public float getFloat(final long ptr) { public void putFloat(final long ptr, final float value) { validatePointer(ptr); - int i = Primitives.castToInt(ptr / _chunkBytes); + int i = NumberUtils.castToInt(ptr / _chunkBytes); int[] chunk = _chunks[i]; long j = offset(ptr); _UNSAFE.putFloat(chunk, j, value); @@ -239,7 +238,7 @@ public void putFloat(final long ptr, final float value) { public double getDouble(final long ptr) { validatePointer(ptr); - int i = Primitives.castToInt(ptr / _chunkBytes); + int i = NumberUtils.castToInt(ptr / _chunkBytes); int[] chunk = _chunks[i]; long j = offset(ptr); return _UNSAFE.getDouble(chunk, j); @@ -247,7 +246,7 @@ public double getDouble(final long ptr) { public void putDouble(final long ptr, final double value) { validatePointer(ptr); - int i = Primitives.castToInt(ptr / _chunkBytes); + int i = NumberUtils.castToInt(ptr / _chunkBytes); int[] chunk = _chunks[i]; long j = offset(ptr); _UNSAFE.putDouble(chunk, j, value); @@ -260,7 +259,7 @@ public void getFloats(final long ptr, @Nonnull final float[] values) { throw new IllegalArgumentException("Cannot put empty array at " + ptr); } - int chunkIdx = Primitives.castToInt(ptr / _chunkBytes); + int chunkIdx = NumberUtils.castToInt(ptr / _chunkBytes); final int[] chunk = _chunks[chunkIdx]; final long base = offset(ptr); for (int i = 0; i < len; i++) { @@ -277,7 +276,7 @@ public void putFloats(final long ptr, @Nonnull final float[] 
values) { throw new IllegalArgumentException("Cannot put empty array at " + ptr); } - int chunkIdx = Primitives.castToInt(ptr / _chunkBytes); + int chunkIdx = NumberUtils.castToInt(ptr / _chunkBytes); final int[] chunk = _chunks[chunkIdx]; final long base = offset(ptr); for (int i = 0; i < len; i++) { diff --git a/core/src/main/java/hivemall/utils/collections/maps/Int2FloatOpenHashTable.java b/core/src/main/java/hivemall/utils/collections/maps/Int2FloatOpenHashTable.java index f847b1516..e9b5c8a97 100644 --- a/core/src/main/java/hivemall/utils/collections/maps/Int2FloatOpenHashTable.java +++ b/core/src/main/java/hivemall/utils/collections/maps/Int2FloatOpenHashTable.java @@ -27,8 +27,13 @@ import java.util.Arrays; /** - * An open-addressing hash table with double hashing - * + * An open-addressing hash table using double hashing. + * + *
+ * Primary hash function: h1(k) = k mod m
+ * Secondary hash function: h2(k) = 1 + (k mod (m-2))
+ *
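// Editorial note (not part of the patch): a minimal, self-contained sketch of the
// double-hashing probe documented above, assuming a prime table size m and the
// FREE/FULL/REMOVED slot states these tables use. Names are illustrative only;
// this mirrors the findKey() logic visible in this diff, not the exact class.
final class DoubleHashingProbeSketch {
    private static final byte FREE = 0;
    private static final byte FULL = 1;

    /** Returns the slot holding {@code key}, or -1 if absent. */
    static int find(final int key, final int[] keys, final byte[] states) {
        final int m = keys.length;          // a prime, per Primes.findLeastPrimeNumber
        final int hash = key & 0x7fffffff;  // keyHash(): drop the sign bit
        int idx = hash % m;                 // h1(k) = k mod m
        if (states[idx] == FULL && keys[idx] == key) {
            return idx;
        }
        final int decr = 1 + (hash % (m - 2)); // h2(k) = 1 + (k mod (m-2))
        while (states[idx] != FREE) {       // REMOVED slots keep the probe walking
            idx -= decr;
            if (idx < 0) {
                idx += m;                   // wrap around
            }
            if (states[idx] == FULL && keys[idx] == key) {
                return idx;
            }
        }
        return -1; // hit a FREE slot; the load factor guarantees one exists
    }

    public static void main(String[] args) {
        int[] keys = {0, 0, 17, 3, 0};                 // toy table, m = 5 (prime)
        byte[] states = {FREE, FREE, FULL, FULL, FREE};
        System.out.println(find(17, keys, states));    // 2 (17 mod 5)
        System.out.println(find(9, keys, states));     // -1: probe lands on a FREE slot
    }
}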
+ * * @see http://en.wikipedia.org/wiki/Double_hashing */ public class Int2FloatOpenHashTable implements Externalizable { @@ -37,7 +42,7 @@ public class Int2FloatOpenHashTable implements Externalizable { protected static final byte FULL = 1; protected static final byte REMOVED = 2; - private static final float DEFAULT_LOAD_FACTOR = 0.7f; + private static final float DEFAULT_LOAD_FACTOR = 0.75f; private static final float DEFAULT_GROW_FACTOR = 2.0f; protected final transient float _loadFactor; diff --git a/core/src/main/java/hivemall/utils/collections/maps/Int2IntOpenHashTable.java b/core/src/main/java/hivemall/utils/collections/maps/Int2IntOpenHashTable.java index 5e9e81232..8e87fcee1 100644 --- a/core/src/main/java/hivemall/utils/collections/maps/Int2IntOpenHashTable.java +++ b/core/src/main/java/hivemall/utils/collections/maps/Int2IntOpenHashTable.java @@ -27,7 +27,12 @@ import java.util.Arrays; /** - * An open-addressing hash table with double hashing + * An open-addressing hash table using double hashing. + * + *
+ * Primary hash function: h1(k) = k mod m
+ * Secondary hash function: h2(k) = 1 + (k mod (m-2))
+ *
* * @see http://en.wikipedia.org/wiki/Double_hashing */ @@ -37,7 +42,7 @@ public final class Int2IntOpenHashTable implements Externalizable { protected static final byte FULL = 1; protected static final byte REMOVED = 2; - private static final float DEFAULT_LOAD_FACTOR = 0.7f; + private static final float DEFAULT_LOAD_FACTOR = 0.75f; private static final float DEFAULT_GROW_FACTOR = 2.0f; protected final transient float _loadFactor; diff --git a/core/src/main/java/hivemall/utils/collections/maps/Int2LongOpenHashMap.java b/core/src/main/java/hivemall/utils/collections/maps/Int2LongOpenHashMap.java new file mode 100644 index 000000000..ffa80d0d4 --- /dev/null +++ b/core/src/main/java/hivemall/utils/collections/maps/Int2LongOpenHashMap.java @@ -0,0 +1,346 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +// +// Copyright (C) 2010 catchpole.net +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +package hivemall.utils.collections.maps; + +import hivemall.utils.hashing.HashUtils; +import hivemall.utils.math.MathUtils; + +import java.util.Arrays; + +import javax.annotation.Nonnull; +import javax.annotation.concurrent.NotThreadSafe; + +/** + * A space efficient open-addressing HashMap implementation with integer keys and long values. + * + * Unlike {@link Int2LongOpenHashTable}, it maintains single arrays for keys and object references. + * + * It uses single open hashing arrays sized to binary powers (256, 512 etc) rather than those + * divisible by prime numbers. This allows the hash offset calculation to be a simple binary masking + * operation. + * + * The index into the arrays is determined by masking a portion of the key and shifting it to + * provide a series of small buckets within the array. To insert an entry the a sweep is searched + * until an empty key space is found. A sweep is 4 times the length of a bucket, to reduce the need + * to rehash. If no key space is found within a sweep, the table size is doubled. + * + * While performance is high, the slowest situation is where lookup occurs for entries that do not + * exist, as an entire sweep area must be searched. However, this HashMap is more space efficient + * than other open-addressing HashMap implementations as in fastutil. 
+ */ +@NotThreadSafe +public final class Int2LongOpenHashMap { + + // special treatment for key=0 + private boolean hasKey0 = false; + private long value0 = 0L; + + private int[] keys; + private long[] values; + + // total number of entries in this table + private int size; + // number of bits for the value table (eg. 8 bits = 256 entries) + private int bits; + // the number of bits in each sweep zone. + private int sweepbits; + // the size of a sweep (2 to the power of sweepbits) + private int sweep; + // the sweepmask used to create sweep zone offsets + private int sweepmask; + + public Int2LongOpenHashMap(int size) { + resize(MathUtils.bitsRequired(size < 256 ? 256 : size)); + } + + public long put(final int key, final long value) { + if (key == 0) { + if (!hasKey0) { + this.hasKey0 = true; + size++; + } + long old = value0; + this.value0 = value; + return old; + } + + for (;;) { + int off = getBucketOffset(key); + final int end = off + sweep; + for (; off < end; off++) { + final int searchKey = keys[off]; + if (searchKey == 0) { // insert + keys[off] = key; + size++; + long previous = values[off]; + values[off] = value; + return previous; + } else if (searchKey == key) {// replace + long previous = values[off]; + values[off] = value; + return previous; + } + } + resize(this.bits + 1); + } + } + + public long putIfAbsent(final int key, final long value) { + if (key == 0) { + if (hasKey0) { + return value0; + } + this.hasKey0 = true; + long old = value0; + this.value0 = value; + size++; + return old; + } + + for (;;) { + int off = getBucketOffset(key); + final int end = off + sweep; + for (; off < end; off++) { + final int searchKey = keys[off]; + if (searchKey == 0) { // insert + keys[off] = key; + size++; + long previous = values[off]; + values[off] = value; + return previous; + } else if (searchKey == key) {// replace + return values[off]; + } + } + resize(this.bits + 1); + } + } + + public long get(final int key) { + return get(key, 0L); + } + + public long get(final int key, final long defaultValue) { + if (key == 0) { + return hasKey0 ? 
value0 : defaultValue; + } + + int off = getBucketOffset(key); + final int end = sweep + off; + for (; off < end; off++) { + if (keys[off] == key) { + return values[off]; + } + } + return defaultValue; + } + + public long remove(final int key, final long defaultValue) { + if (key == 0) { + if (hasKey0) { + this.hasKey0 = false; + long old = value0; + this.value0 = 0L; + size--; + return old; + } else { + return defaultValue; + } + } + + int off = getBucketOffset(key); + final int end = sweep + off; + for (; off < end; off++) { + if (keys[off] == key) { + keys[off] = 0; + long previous = values[off]; + values[off] = 0L; + size--; + return previous; + } + } + return defaultValue; + } + + public int size() { + return size; + } + + public boolean isEmpty() { + return size == 0; + } + + public boolean containsKey(final int key) { + if (key == 0) { + return hasKey0; + } + + int off = getBucketOffset(key); + final int end = sweep + off; + for (; off < end; off++) { + if (keys[off] == key) { + return true; + } + } + return false; + } + + public void clear() { + this.hasKey0 = false; + this.value0 = 0L; + Arrays.fill(keys, 0); + Arrays.fill(values, 0L); + this.size = 0; + } + + @Override + public String toString() { + return this.getClass().getSimpleName() + ' ' + size; + } + + private void resize(final int bits) { + this.bits = bits; + this.sweepbits = bits / 4; + this.sweep = MathUtils.powerOf(2, sweepbits) * 4; + this.sweepmask = MathUtils.bitMask(bits - sweepbits) << sweepbits; + + // remember old values so we can recreate the entries + final int[] existingKeys = this.keys; + final long[] existingValues = this.values; + + // create the arrays + this.values = new long[MathUtils.powerOf(2, bits) + sweep]; + this.keys = new int[values.length]; + this.size = hasKey0 ? 1 : 0; + + // re-add the previous entries if resizing + if (existingKeys != null) { + for (int i = 0; i < existingKeys.length; i++) { + final int k = existingKeys[i]; + if (k != 0) { + put(k, existingValues[i]); + } + } + } + } + + private int getBucketOffset(final int key) { + return (HashUtils.fnv1a(key) << sweepbits) & sweepmask; + } + + @Nonnull + public MapIterator entries() { + return new MapIterator(); + } + + public final class MapIterator { + + int nextEntry; + int lastEntry = -2; + + MapIterator() { + this.nextEntry = nextEntry(-1); + } + + /** find the index of next full entry */ + int nextEntry(int index) { + if (index == -1) { + if (hasKey0) { + return -1; + } else { + index = 0; + } + } + while (index < keys.length && keys[index] == 0) { + index++; + } + return index; + } + + public boolean hasNext() { + return nextEntry < keys.length; + } + + public boolean next() { + free(lastEntry); + if (!hasNext()) { + return false; + } + int curEntry = nextEntry; + this.lastEntry = curEntry; + this.nextEntry = nextEntry(curEntry + 1); + return true; + } + + public int getKey() { + if (lastEntry >= 0 && lastEntry < keys.length) { + return keys[lastEntry]; + } else if (lastEntry == -1) { + return 0; + } else { + throw new IllegalStateException( + "next() should be called before getKey(). lastEntry=" + lastEntry + + ", keys.length=" + keys.length); + } + } + + public long getValue() { + if (lastEntry >= 0 && lastEntry < keys.length) { + return values[lastEntry]; + } else if (lastEntry == -1) { + return value0; + } else { + throw new IllegalStateException( + "next() should be called before getKey(). 
lastEntry=" + lastEntry + + ", keys.length=" + keys.length); + } + } + + private void free(int index) { + if (index >= 0) { + if (index >= keys.length) { + throw new IllegalStateException("index=" + index + ", keys.length=" + + keys.length); + } + keys[index] = 0; + values[index] = 0L; + } else if (index == -1) { + hasKey0 = false; + value0 = 0L; + } + // index may be -2 + } + + } +} diff --git a/core/src/main/java/hivemall/utils/collections/maps/Int2LongOpenHashTable.java b/core/src/main/java/hivemall/utils/collections/maps/Int2LongOpenHashTable.java index 68eb42fe9..22acdb4a0 100644 --- a/core/src/main/java/hivemall/utils/collections/maps/Int2LongOpenHashTable.java +++ b/core/src/main/java/hivemall/utils/collections/maps/Int2LongOpenHashTable.java @@ -33,7 +33,12 @@ import javax.annotation.Nonnull; /** - * An open-addressing hash table with double hashing + * An open-addressing hash table using double hashing. + * + *
+ * Primary hash function: h1(k) = k mod m
+ * Secondary hash function: h2(k) = 1 + (k mod (m-2))
+ *
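// Editorial note (not part of the patch): why remove() in these tables flips a slot
// to REMOVED rather than FREE. A later key whose probe chain stepped over this slot
// would otherwise become unreachable, because unsuccessful lookups stop at the first
// FREE slot. The tombstone keeps the chain walkable, and the isFree() check visible
// in this diff lets put() reuse the slot when the same key returns. Sketch only.
final class TombstoneSketch {
    static final byte FREE = 0, FULL = 1, REMOVED = 2;

    static void removeAt(final byte[] states, final int idx) {
        states[idx] = REMOVED; // entry is dead, but probe chains through it survive
    }

    static boolean probeStopsHere(final byte state) {
        return state == FREE;  // only a truly free slot ends an unsuccessful probe
    }
}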
* * @see http://en.wikipedia.org/wiki/Double_hashing */ @@ -44,7 +49,7 @@ public class Int2LongOpenHashTable implements Externalizable { protected static final byte REMOVED = 2; public static final int DEFAULT_SIZE = 65536; - public static final float DEFAULT_LOAD_FACTOR = 0.7f; + public static final float DEFAULT_LOAD_FACTOR = 0.75f; public static final float DEFAULT_GROW_FACTOR = 2.0f; protected final transient float _loadFactor; @@ -123,23 +128,23 @@ public byte[] getStates() { return _states; } - public boolean containsKey(int key) { + public boolean containsKey(final int key) { return findKey(key) >= 0; } /** * @return -1.f if not found */ - public long get(int key) { - int i = findKey(key); + public long get(final int key) { + final int i = findKey(key); if (i < 0) { return defaultReturnValue; } return _values[i]; } - public long put(int key, long value) { - int hash = keyHash(key); + public long put(final int key, final long value) { + final int hash = keyHash(key); int keyLength = _keys.length; int keyIdx = hash % keyLength; @@ -149,9 +154,9 @@ public long put(int key, long value) { keyIdx = hash % keyLength; } - int[] keys = _keys; - long[] values = _values; - byte[] states = _states; + final int[] keys = _keys; + final long[] values = _values; + final byte[] states = _states; if (states[keyIdx] == FULL) {// double hashing if (keys[keyIdx] == key) { @@ -160,7 +165,7 @@ public long put(int key, long value) { return old; } // try second hash - int decr = 1 + (hash % (keyLength - 2)); + final int decr = 1 + (hash % (keyLength - 2)); for (;;) { keyIdx -= decr; if (keyIdx < 0) { @@ -184,8 +189,8 @@ public long put(int key, long value) { } /** Return weather the required slot is free for new entry */ - protected boolean isFree(int index, int key) { - byte stat = _states[index]; + protected boolean isFree(final int index, final int key) { + final byte stat = _states[index]; if (stat == FREE) { return true; } @@ -196,7 +201,7 @@ protected boolean isFree(int index, int key) { } /** @return expanded or not */ - protected boolean preAddEntry(int index) { + protected boolean preAddEntry(final int index) { if ((_used + 1) >= _threshold) {// too filled int newCapacity = Math.round(_keys.length * _growFactor); ensureCapacity(newCapacity); @@ -205,19 +210,19 @@ protected boolean preAddEntry(int index) { return false; } - protected int findKey(int key) { - int[] keys = _keys; - byte[] states = _states; - int keyLength = keys.length; + protected int findKey(final int key) { + final int[] keys = _keys; + final byte[] states = _states; + final int keyLength = keys.length; - int hash = keyHash(key); + final int hash = keyHash(key); int keyIdx = hash % keyLength; if (states[keyIdx] != FREE) { if (states[keyIdx] == FULL && keys[keyIdx] == key) { return keyIdx; } // try second hash - int decr = 1 + (hash % (keyLength - 2)); + final int decr = 1 + (hash % (keyLength - 2)); for (;;) { keyIdx -= decr; if (keyIdx < 0) { @@ -234,13 +239,13 @@ protected int findKey(int key) { return -1; } - public long remove(int key) { - int[] keys = _keys; - long[] values = _values; - byte[] states = _states; - int keyLength = keys.length; + public long remove(final int key) { + final int[] keys = _keys; + final long[] values = _values; + final byte[] states = _states; + final int keyLength = keys.length; - int hash = keyHash(key); + final int hash = keyHash(key); int keyIdx = hash % keyLength; if (states[keyIdx] != FREE) { if (states[keyIdx] == FULL && keys[keyIdx] == key) { @@ -250,7 +255,7 @@ public long remove(int key) { 
return old; } // second hash - int decr = 1 + (hash % (keyLength - 2)); + final int decr = 1 + (hash % (keyLength - 2)); for (;;) { keyIdx -= decr; if (keyIdx < 0) { @@ -283,21 +288,22 @@ public void clear() { this._used = 0; } - public IMapIterator entries() { + @Nonnull + public MapIterator entries() { return new MapIterator(); } @Override public String toString() { int len = size() * 10 + 2; - StringBuilder buf = new StringBuilder(len); + final StringBuilder buf = new StringBuilder(len); buf.append('{'); - IMapIterator i = entries(); - while (i.next() != -1) { - buf.append(i.getKey()); + final MapIterator itor = entries(); + while (itor.next() != -1) { + buf.append(itor.getKey()); buf.append('='); - buf.append(i.getValue()); - if (i.hasNext()) { + buf.append(itor.getValue()); + if (itor.hasNext()) { buf.append(','); } } @@ -305,30 +311,30 @@ public String toString() { return buf.toString(); } - protected void ensureCapacity(int newCapacity) { + protected void ensureCapacity(final int newCapacity) { int prime = Primes.findLeastPrimeNumber(newCapacity); rehash(prime); this._threshold = Math.round(prime * _loadFactor); } - private void rehash(int newCapacity) { + private void rehash(final int newCapacity) { int oldCapacity = _keys.length; if (newCapacity <= oldCapacity) { throw new IllegalArgumentException("new: " + newCapacity + ", old: " + oldCapacity); } - int[] newkeys = new int[newCapacity]; - long[] newValues = new long[newCapacity]; - byte[] newStates = new byte[newCapacity]; + final int[] newkeys = new int[newCapacity]; + final long[] newValues = new long[newCapacity]; + final byte[] newStates = new byte[newCapacity]; int used = 0; for (int i = 0; i < oldCapacity; i++) { if (_states[i] == FULL) { used++; - int k = _keys[i]; - long v = _values[i]; - int hash = keyHash(k); + final int k = _keys[i]; + final long v = _values[i]; + final int hash = keyHash(k); int keyIdx = hash % newCapacity; if (newStates[keyIdx] == FULL) {// second hashing - int decr = 1 + (hash % (newCapacity - 2)); + final int decr = 1 + (hash % (newCapacity - 2)); while (newStates[keyIdx] != FREE) { keyIdx -= decr; if (keyIdx < 0) { @@ -347,7 +353,7 @@ private void rehash(int newCapacity) { this._used = used; } - private static int keyHash(int key) { + private static int keyHash(final int key) { return key & 0x7fffffff; } @@ -437,22 +443,7 @@ private static void readStates(@Nonnull final DataInput in, @Nonnull final byte[ } } - public interface IMapIterator { - - public boolean hasNext(); - - /** - * @return -1 if not found - */ - public int next(); - - public int getKey(); - - public long getValue(); - - } - - private final class MapIterator implements IMapIterator { + public final class MapIterator { int nextEntry; int lastEntry = -1; @@ -473,6 +464,9 @@ public boolean hasNext() { return nextEntry < _keys.length; } + /** + * @return -1 if not found + */ public int next() { if (!hasNext()) { return -1; diff --git a/core/src/main/java/hivemall/utils/collections/maps/IntOpenHashMap.java b/core/src/main/java/hivemall/utils/collections/maps/IntOpenHashMap.java deleted file mode 100644 index 5ce34a498..000000000 --- a/core/src/main/java/hivemall/utils/collections/maps/IntOpenHashMap.java +++ /dev/null @@ -1,467 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. 
The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package hivemall.utils.collections.maps; - -import hivemall.utils.math.Primes; - -import java.io.Externalizable; -import java.io.IOException; -import java.io.ObjectInput; -import java.io.ObjectOutput; -import java.util.Arrays; - -/** - * An open-addressing hash table with double hashing - * - * @see http://en.wikipedia.org/wiki/Double_hashing - */ -public class IntOpenHashMap implements Externalizable { - private static final long serialVersionUID = -8162355845665353513L; - - protected static final byte FREE = 0; - protected static final byte FULL = 1; - protected static final byte REMOVED = 2; - - private static final float DEFAULT_LOAD_FACTOR = 0.7f; - private static final float DEFAULT_GROW_FACTOR = 2.0f; - - protected final transient float _loadFactor; - protected final transient float _growFactor; - - protected int _used = 0; - protected int _threshold; - - protected int[] _keys; - protected V[] _values; - protected byte[] _states; - - @SuppressWarnings("unchecked") - protected IntOpenHashMap(int size, float loadFactor, float growFactor, boolean forcePrime) { - if (size < 1) { - throw new IllegalArgumentException(); - } - this._loadFactor = loadFactor; - this._growFactor = growFactor; - int actualSize = forcePrime ? 
Primes.findLeastPrimeNumber(size) : size; - this._keys = new int[actualSize]; - this._values = (V[]) new Object[actualSize]; - this._states = new byte[actualSize]; - this._threshold = Math.round(actualSize * _loadFactor); - } - - public IntOpenHashMap(int size, float loadFactor, float growFactor) { - this(size, loadFactor, growFactor, true); - } - - public IntOpenHashMap(int size) { - this(size, DEFAULT_LOAD_FACTOR, DEFAULT_GROW_FACTOR, true); - } - - public IntOpenHashMap() {// required for serialization - this._loadFactor = DEFAULT_LOAD_FACTOR; - this._growFactor = DEFAULT_GROW_FACTOR; - } - - public boolean containsKey(int key) { - return findKey(key) >= 0; - } - - public final V get(final int key) { - final int i = findKey(key); - if (i < 0) { - return null; - } - recordAccess(i); - return _values[i]; - } - - public V put(final int key, final V value) { - final int hash = keyHash(key); - int keyLength = _keys.length; - int keyIdx = hash % keyLength; - - final boolean expanded = preAddEntry(keyIdx); - if (expanded) { - keyLength = _keys.length; - keyIdx = hash % keyLength; - } - - final int[] keys = _keys; - final V[] values = _values; - final byte[] states = _states; - - if (states[keyIdx] == FULL) {// double hashing - if (keys[keyIdx] == key) { - V old = values[keyIdx]; - values[keyIdx] = value; - recordAccess(keyIdx); - return old; - } - // try second hash - final int decr = 1 + (hash % (keyLength - 2)); - for (;;) { - keyIdx -= decr; - if (keyIdx < 0) { - keyIdx += keyLength; - } - if (isFree(keyIdx, key)) { - break; - } - if (states[keyIdx] == FULL && keys[keyIdx] == key) { - V old = values[keyIdx]; - values[keyIdx] = value; - recordAccess(keyIdx); - return old; - } - } - } - keys[keyIdx] = key; - values[keyIdx] = value; - states[keyIdx] = FULL; - ++_used; - postAddEntry(keyIdx); - return null; - } - - public V putIfAbsent(final int key, final V value) { - final int hash = keyHash(key); - int keyLength = _keys.length; - int keyIdx = hash % keyLength; - - final boolean expanded = preAddEntry(keyIdx); - if (expanded) { - keyLength = _keys.length; - keyIdx = hash % keyLength; - } - - final int[] keys = _keys; - final V[] values = _values; - final byte[] states = _states; - - if (states[keyIdx] == FULL) {// second hashing - if (keys[keyIdx] == key) { - return values[keyIdx]; - } - // try second hash - final int decr = 1 + (hash % (keyLength - 2)); - for (;;) { - keyIdx -= decr; - if (keyIdx < 0) { - keyIdx += keyLength; - } - if (isFree(keyIdx, key)) { - break; - } - if (states[keyIdx] == FULL && keys[keyIdx] == key) { - return values[keyIdx]; - } - } - } - keys[keyIdx] = key; - values[keyIdx] = value; - states[keyIdx] = FULL; - _used++; - postAddEntry(keyIdx); - return null; - } - - /** Return weather the required slot is free for new entry */ - protected boolean isFree(int index, int key) { - byte stat = _states[index]; - if (stat == FREE) { - return true; - } - if (stat == REMOVED && _keys[index] == key) { - return true; - } - return false; - } - - /** @return expanded or not */ - protected boolean preAddEntry(int index) { - if ((_used + 1) >= _threshold) {// too filled - int newCapacity = Math.round(_keys.length * _growFactor); - ensureCapacity(newCapacity); - return true; - } - return false; - } - - protected void postAddEntry(int index) {} - - private int findKey(int key) { - int[] keys = _keys; - byte[] states = _states; - int keyLength = keys.length; - - int hash = keyHash(key); - int keyIdx = hash % keyLength; - if (states[keyIdx] != FREE) { - if (states[keyIdx] == FULL && 
keys[keyIdx] == key) { - return keyIdx; - } - // try second hash - int decr = 1 + (hash % (keyLength - 2)); - for (;;) { - keyIdx -= decr; - if (keyIdx < 0) { - keyIdx += keyLength; - } - if (isFree(keyIdx, key)) { - return -1; - } - if (states[keyIdx] == FULL && keys[keyIdx] == key) { - return keyIdx; - } - } - } - return -1; - } - - public V remove(int key) { - int[] keys = _keys; - V[] values = _values; - byte[] states = _states; - int keyLength = keys.length; - - int hash = keyHash(key); - int keyIdx = hash % keyLength; - if (states[keyIdx] != FREE) { - if (states[keyIdx] == FULL && keys[keyIdx] == key) { - V old = values[keyIdx]; - states[keyIdx] = REMOVED; - --_used; - recordRemoval(keyIdx); - return old; - } - // second hash - int decr = 1 + (hash % (keyLength - 2)); - for (;;) { - keyIdx -= decr; - if (keyIdx < 0) { - keyIdx += keyLength; - } - if (states[keyIdx] == FREE) { - return null; - } - if (states[keyIdx] == FULL && keys[keyIdx] == key) { - V old = values[keyIdx]; - states[keyIdx] = REMOVED; - --_used; - recordRemoval(keyIdx); - return old; - } - } - } - return null; - } - - public int size() { - return _used; - } - - public void clear() { - Arrays.fill(_states, FREE); - this._used = 0; - } - - @SuppressWarnings("unchecked") - public IMapIterator entries() { - return new MapIterator(); - } - - @Override - public String toString() { - int len = size() * 10 + 2; - StringBuilder buf = new StringBuilder(len); - buf.append('{'); - IMapIterator i = entries(); - while (i.next() != -1) { - buf.append(i.getKey()); - buf.append('='); - buf.append(i.getValue()); - if (i.hasNext()) { - buf.append(','); - } - } - buf.append('}'); - return buf.toString(); - } - - private void ensureCapacity(int newCapacity) { - int prime = Primes.findLeastPrimeNumber(newCapacity); - rehash(prime); - this._threshold = Math.round(prime * _loadFactor); - } - - @SuppressWarnings("unchecked") - protected void rehash(int newCapacity) { - int oldCapacity = _keys.length; - if (newCapacity <= oldCapacity) { - throw new IllegalArgumentException("new: " + newCapacity + ", old: " + oldCapacity); - } - final int[] oldKeys = _keys; - final V[] oldValues = _values; - final byte[] oldStates = _states; - int[] newkeys = new int[newCapacity]; - V[] newValues = (V[]) new Object[newCapacity]; - byte[] newStates = new byte[newCapacity]; - int used = 0; - for (int i = 0; i < oldCapacity; i++) { - if (oldStates[i] == FULL) { - used++; - int k = oldKeys[i]; - V v = oldValues[i]; - int hash = keyHash(k); - int keyIdx = hash % newCapacity; - if (newStates[keyIdx] == FULL) {// second hashing - int decr = 1 + (hash % (newCapacity - 2)); - while (newStates[keyIdx] != FREE) { - keyIdx -= decr; - if (keyIdx < 0) { - keyIdx += newCapacity; - } - } - } - newkeys[keyIdx] = k; - newValues[keyIdx] = v; - newStates[keyIdx] = FULL; - } - } - this._keys = newkeys; - this._values = newValues; - this._states = newStates; - this._used = used; - } - - private static int keyHash(int key) { - return key & 0x7fffffff; - } - - protected void recordAccess(int idx) {} - - protected void recordRemoval(int idx) {} - - public void writeExternal(ObjectOutput out) throws IOException { - out.writeInt(_threshold); - out.writeInt(_used); - - out.writeInt(_keys.length); - IMapIterator i = entries(); - while (i.next() != -1) { - out.writeInt(i.getKey()); - out.writeObject(i.getValue()); - } - } - - @SuppressWarnings("unchecked") - public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { - this._threshold = in.readInt(); - this._used 
= in.readInt(); - - int keylen = in.readInt(); - int[] keys = new int[keylen]; - V[] values = (V[]) new Object[keylen]; - byte[] states = new byte[keylen]; - for (int i = 0; i < _used; i++) { - int k = in.readInt(); - V v = (V) in.readObject(); - int hash = keyHash(k); - int keyIdx = hash % keylen; - if (states[keyIdx] != FREE) {// second hash - int decr = 1 + (hash % (keylen - 2)); - for (;;) { - keyIdx -= decr; - if (keyIdx < 0) { - keyIdx += keylen; - } - if (states[keyIdx] == FREE) { - break; - } - } - } - states[keyIdx] = FULL; - keys[keyIdx] = k; - values[keyIdx] = v; - } - this._keys = keys; - this._values = values; - this._states = states; - } - - public interface IMapIterator { - - public boolean hasNext(); - - public int next(); - - public int getKey(); - - public V getValue(); - - } - - @SuppressWarnings("rawtypes") - private final class MapIterator implements IMapIterator { - - int nextEntry; - int lastEntry = -1; - - MapIterator() { - this.nextEntry = nextEntry(0); - } - - /** find the index of next full entry */ - int nextEntry(int index) { - while (index < _keys.length && _states[index] != FULL) { - index++; - } - return index; - } - - public boolean hasNext() { - return nextEntry < _keys.length; - } - - public int next() { - if (!hasNext()) { - return -1; - } - int curEntry = nextEntry; - this.lastEntry = curEntry; - this.nextEntry = nextEntry(curEntry + 1); - return curEntry; - } - - public int getKey() { - if (lastEntry == -1) { - throw new IllegalStateException(); - } - return _keys[lastEntry]; - } - - public V getValue() { - if (lastEntry == -1) { - throw new IllegalStateException(); - } - return _values[lastEntry]; - } - } - -} diff --git a/core/src/main/java/hivemall/utils/collections/maps/IntOpenHashTable.java b/core/src/main/java/hivemall/utils/collections/maps/IntOpenHashTable.java index dcb64d12b..dbade7499 100644 --- a/core/src/main/java/hivemall/utils/collections/maps/IntOpenHashTable.java +++ b/core/src/main/java/hivemall/utils/collections/maps/IntOpenHashTable.java @@ -25,54 +25,68 @@ import java.io.ObjectInput; import java.io.ObjectOutput; import java.util.Arrays; -import java.util.HashMap; import javax.annotation.Nonnull; /** - * An open-addressing hash table with double-hashing that requires less memory to {@link HashMap}. + * An open-addressing hash table using double hashing. + * + *
+ * Primary hash function: h1(k) = k mod m
+ * Secondary hash function: h2(k) = 1 + (k mod (m-2))
+ *
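// Editorial note (not part of the patch): usage sketch for the putIfAbsent() this
// diff adds to IntOpenHashTable; like java.util.Map#putIfAbsent it returns the
// existing value when the key is already present and null after inserting. The
// example below is illustrative, not a test from this repository.
import hivemall.utils.collections.maps.IntOpenHashTable;

final class PutIfAbsentSketch {
    public static void main(String[] args) {
        IntOpenHashTable<String> table = new IntOpenHashTable<String>(64);
        String first = table.putIfAbsent(7, "first");   // null: key 7 was absent
        String second = table.putIfAbsent(7, "second"); // "first": existing mapping wins
        System.out.println(first + " / " + second);     // prints: null / first
    }
}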
+ * + * @see http://en.wikipedia.org/wiki/Double_hashing */ public final class IntOpenHashTable implements Externalizable { + private static final long serialVersionUID = -8162355845665353513L; - public static final float DEFAULT_LOAD_FACTOR = 0.7f; + public static final float DEFAULT_LOAD_FACTOR = 0.75f; public static final float DEFAULT_GROW_FACTOR = 2.0f; - public static final byte FREE = 0; - public static final byte FULL = 1; - public static final byte REMOVED = 2; + protected static final byte FREE = 0; + protected static final byte FULL = 1; + protected static final byte REMOVED = 2; protected/* final */float _loadFactor; protected/* final */float _growFactor; - protected int _used = 0; + protected int _used; protected int _threshold; protected int[] _keys; protected V[] _values; protected byte[] _states; - public IntOpenHashTable() {} // for Externalizable + public IntOpenHashTable() {} // for Externalizable public IntOpenHashTable(int size) { - this(size, DEFAULT_LOAD_FACTOR, DEFAULT_GROW_FACTOR); + this(size, DEFAULT_LOAD_FACTOR, DEFAULT_GROW_FACTOR, true); } - @SuppressWarnings("unchecked") public IntOpenHashTable(int size, float loadFactor, float growFactor) { + this(size, loadFactor, growFactor, true); + } + + @SuppressWarnings("unchecked") + protected IntOpenHashTable(int size, float loadFactor, float growFactor, boolean forcePrime) { if (size < 1) { throw new IllegalArgumentException(); } this._loadFactor = loadFactor; this._growFactor = growFactor; - int actualSize = Primes.findLeastPrimeNumber(size); + this._used = 0; + int actualSize = forcePrime ? Primes.findLeastPrimeNumber(size) : size; + this._threshold = Math.round(actualSize * _loadFactor); this._keys = new int[actualSize]; this._values = (V[]) new Object[actualSize]; this._states = new byte[actualSize]; - this._threshold = Math.round(actualSize * _loadFactor); } public IntOpenHashTable(@Nonnull int[] keys, @Nonnull V[] values, @Nonnull byte[] states, int used) { + this._loadFactor = DEFAULT_LOAD_FACTOR; + this._growFactor = DEFAULT_GROW_FACTOR; this._used = used; this._threshold = keys.length; this._keys = keys; @@ -80,14 +94,17 @@ public IntOpenHashTable(@Nonnull int[] keys, @Nonnull V[] values, @Nonnull byte[ this._states = states; } + @Nonnull public int[] getKeys() { return _keys; } + @Nonnull public Object[] getValues() { return _values; } + @Nonnull public byte[] getStates() { return _states; } @@ -109,7 +126,7 @@ public V put(final int key, final V value) { int keyLength = _keys.length; int keyIdx = hash % keyLength; - boolean expanded = preAddEntry(keyIdx); + final boolean expanded = preAddEntry(keyIdx); if (expanded) { keyLength = _keys.length; keyIdx = hash % keyLength; @@ -119,14 +136,14 @@ public V put(final int key, final V value) { final V[] values = _values; final byte[] states = _states; - if (states[keyIdx] == FULL) { + if (states[keyIdx] == FULL) {// double hashing if (keys[keyIdx] == key) { V old = values[keyIdx]; values[keyIdx] = value; return old; } // try second hash - int decr = 1 + (hash % (keyLength - 2)); + final int decr = 1 + (hash % (keyLength - 2)); for (;;) { keyIdx -= decr; if (keyIdx < 0) { @@ -149,10 +166,50 @@ public V put(final int key, final V value) { return null; } + public V putIfAbsent(final int key, final V value) { + final int hash = keyHash(key); + int keyLength = _keys.length; + int keyIdx = hash % keyLength; + + final boolean expanded = preAddEntry(keyIdx); + if (expanded) { + keyLength = _keys.length; + keyIdx = hash % keyLength; + } + + final int[] keys = _keys; + 
final V[] values = _values; + final byte[] states = _states; + + if (states[keyIdx] == FULL) {// second hashing + if (keys[keyIdx] == key) { + return values[keyIdx]; + } + // try second hash + final int decr = 1 + (hash % (keyLength - 2)); + for (;;) { + keyIdx -= decr; + if (keyIdx < 0) { + keyIdx += keyLength; + } + if (isFree(keyIdx, key)) { + break; + } + if (states[keyIdx] == FULL && keys[keyIdx] == key) { + return values[keyIdx]; + } + } + } + keys[keyIdx] = key; + values[keyIdx] = value; + states[keyIdx] = FULL; + _used++; + return null; + } /** Return weather the required slot is free for new entry */ - protected boolean isFree(int index, int key) { - byte stat = _states[index]; + protected boolean isFree(final int index, final int key) { + final byte stat = _states[index]; if (stat == FREE) { return true; } @@ -163,8 +220,8 @@ protected boolean isFree(int index, int key) { } /** @return expanded or not */ - protected boolean preAddEntry(int index) { - if ((_used + 1) >= _threshold) {// filled enough + protected boolean preAddEntry(final int index) { + if ((_used + 1) >= _threshold) {// too filled int newCapacity = Math.round(_keys.length * _growFactor); ensureCapacity(newCapacity); return true; @@ -172,7 +229,7 @@ protected boolean preAddEntry(int index) { return false; } - protected int findKey(final int key) { + private int findKey(final int key) { final int[] keys = _keys; final byte[] states = _states; final int keyLength = keys.length; @@ -184,7 +241,7 @@ protected int findKey(final int key) { return keyIdx; } // try second hash - int decr = 1 + (hash % (keyLength - 2)); + final int decr = 1 + (hash % (keyLength - 2)); for (;;) { keyIdx -= decr; if (keyIdx < 0) { @@ -217,7 +274,7 @@ public V remove(final int key) { return old; } // second hash - int decr = 1 + (hash % (keyLength - 2)); + final int decr = 1 + (hash % (keyLength - 2)); for (;;) { keyIdx -= decr; if (keyIdx < 0) { @@ -255,28 +312,49 @@ public void clear() { this._used = 0; } - protected void ensureCapacity(int newCapacity) { + @Override + public String toString() { + int len = size() * 10 + 2; + final StringBuilder buf = new StringBuilder(len); + buf.append('{'); + final IMapIterator i = entries(); + while (i.next() != -1) { + buf.append(i.getKey()); + buf.append('='); + buf.append(i.getValue()); + if (i.hasNext()) { + buf.append(','); + } + } + buf.append('}'); + return buf.toString(); + } + + private void ensureCapacity(final int newCapacity) { int prime = Primes.findLeastPrimeNumber(newCapacity); rehash(prime); this._threshold = Math.round(prime * _loadFactor); } @SuppressWarnings("unchecked") - private void rehash(int newCapacity) { + private void rehash(final int newCapacity) { int oldCapacity = _keys.length; if (newCapacity <= oldCapacity) { throw new IllegalArgumentException("new: " + newCapacity + ", old: " + oldCapacity); } + final int[] oldKeys = _keys; + final V[] oldValues = _values; + final byte[] oldStates = _states; final int[] newkeys = new int[newCapacity]; final V[] newValues = (V[]) new Object[newCapacity]; final byte[] newStates = new byte[newCapacity]; int used = 0; for (int i = 0; i < oldCapacity; i++) { - if (_states[i] == FULL) { + if (oldStates[i] == FULL) { used++; - int k = _keys[i]; - V v = _values[i]; - int hash = keyHash(k); + final int k = oldKeys[i]; + final V v = oldValues[i]; + final int hash = keyHash(k); int keyIdx = hash % newCapacity; if (newStates[keyIdx] == FULL) {// second hashing int decr = 1 + (hash % (newCapacity - 2)); @@ -287,9 +365,9 @@ private void rehash(int 
newCapacity) { } } } - newStates[keyIdx] = FULL; newkeys[keyIdx] = k; newValues[keyIdx] = v; + newStates[keyIdx] = FULL; } } this._keys = newkeys; @@ -303,7 +381,7 @@ private static int keyHash(final int key) { } @Override - public void writeExternal(ObjectOutput out) throws IOException { + public void writeExternal(@Nonnull final ObjectOutput out) throws IOException { out.writeFloat(_loadFactor); out.writeFloat(_growFactor); out.writeInt(_used); @@ -319,8 +397,8 @@ public void writeExternal(ObjectOutput out) throws IOException { } @SuppressWarnings("unchecked") - @Override - public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { + public void readExternal(@Nonnull final ObjectInput in) throws IOException, + ClassNotFoundException { this._loadFactor = in.readFloat(); this._growFactor = in.readFloat(); this._used = in.readInt(); diff --git a/core/src/main/java/hivemall/utils/collections/maps/Long2DoubleOpenHashTable.java b/core/src/main/java/hivemall/utils/collections/maps/Long2DoubleOpenHashTable.java index c758824bc..b4356ff20 100644 --- a/core/src/main/java/hivemall/utils/collections/maps/Long2DoubleOpenHashTable.java +++ b/core/src/main/java/hivemall/utils/collections/maps/Long2DoubleOpenHashTable.java @@ -27,7 +27,12 @@ import java.util.Arrays; /** - * An open-addressing hash table with double hashing + * An open-addressing hash table using double hashing. + * + *
+ * Primary hash function: h1(k) = k mod m
+ * Secondary hash function: h2(k) = 1 + (k mod (m-2))
+ *
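// Editorial note (not part of the patch): arithmetic behind the 0.75f default load
// factor these hunks introduce, following the ensureCapacity()/preAddEntry() logic
// visible in this diff: growth triggers once used+1 reaches round(capacity *
// loadFactor), and the new capacity is round(capacity * growFactor) bumped to the
// next prime. Plain-arithmetic sketch only.
final class LoadFactorSketch {
    public static void main(String[] args) {
        final float loadFactor = 0.75f;
        final float growFactor = 2.0f;
        final int capacity = 1031; // a prime, as Primes.findLeastPrimeNumber returns
        int threshold = Math.round(capacity * loadFactor);    // 773: resize point
        int newCapacity = Math.round(capacity * growFactor);  // 2062 -> next prime 2063
        System.out.println(threshold + " entries trigger growth to >= " + newCapacity);
    }
}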
* * @see http://en.wikipedia.org/wiki/Double_hashing */ @@ -37,7 +42,7 @@ public final class Long2DoubleOpenHashTable implements Externalizable { protected static final byte FULL = 1; protected static final byte REMOVED = 2; - private static final float DEFAULT_LOAD_FACTOR = 0.7f; + private static final float DEFAULT_LOAD_FACTOR = 0.75f; private static final float DEFAULT_GROW_FACTOR = 2.0f; protected final transient float _loadFactor; diff --git a/core/src/main/java/hivemall/utils/collections/maps/Long2FloatOpenHashTable.java b/core/src/main/java/hivemall/utils/collections/maps/Long2FloatOpenHashTable.java index 6a7f39f30..6b0ab59e3 100644 --- a/core/src/main/java/hivemall/utils/collections/maps/Long2FloatOpenHashTable.java +++ b/core/src/main/java/hivemall/utils/collections/maps/Long2FloatOpenHashTable.java @@ -27,9 +27,14 @@ import java.util.Arrays; /** - * An open-addressing hash table with float hashing + * An open-addressing hash table using double hashing. + * + *
+ * Primary hash function: h1(k) = k mod m
+ * Secondary hash function: h2(k) = 1 + (k mod (m-2))
+ * 
* - * @see http://en.wikipedia.org/wiki/float_hashing + * @see http://en.wikipedia.org/wiki/Double_hashing */ public final class Long2FloatOpenHashTable implements Externalizable { @@ -37,7 +42,7 @@ public final class Long2FloatOpenHashTable implements Externalizable { protected static final byte FULL = 1; protected static final byte REMOVED = 2; - private static final float DEFAULT_LOAD_FACTOR = 0.7f; + private static final float DEFAULT_LOAD_FACTOR = 0.75f; private static final float DEFAULT_GROW_FACTOR = 2.0f; protected final transient float _loadFactor; diff --git a/core/src/main/java/hivemall/utils/collections/maps/Long2IntOpenHashTable.java b/core/src/main/java/hivemall/utils/collections/maps/Long2IntOpenHashTable.java index 51b8f1294..1ca4c4023 100644 --- a/core/src/main/java/hivemall/utils/collections/maps/Long2IntOpenHashTable.java +++ b/core/src/main/java/hivemall/utils/collections/maps/Long2IntOpenHashTable.java @@ -27,7 +27,12 @@ import java.util.Arrays; /** - * An open-addressing hash table with double hashing + * An open-addressing hash table using double hashing. + * + *
+ * Primary hash function: h1(k) = k mod m
+ * Secondary hash function: h2(k) = 1 + (k mod (m-2))
+ * 
* * @see http://en.wikipedia.org/wiki/Double_hashing */ @@ -37,7 +42,7 @@ public final class Long2IntOpenHashTable implements Externalizable { protected static final byte FULL = 1; protected static final byte REMOVED = 2; - private static final float DEFAULT_LOAD_FACTOR = 0.7f; + private static final float DEFAULT_LOAD_FACTOR = 0.75f; private static final float DEFAULT_GROW_FACTOR = 2.0f; protected final transient float _loadFactor; diff --git a/core/src/main/java/hivemall/utils/collections/maps/OpenHashMap.java b/core/src/main/java/hivemall/utils/collections/maps/OpenHashMap.java index 152447a02..f5ee1e69c 100644 --- a/core/src/main/java/hivemall/utils/collections/maps/OpenHashMap.java +++ b/core/src/main/java/hivemall/utils/collections/maps/OpenHashMap.java @@ -48,16 +48,29 @@ import java.util.Map; import java.util.Set; +import javax.annotation.CheckForNull; +import javax.annotation.Nonnull; +import javax.annotation.Nullable; + /** - * An optimized Hashed Map implementation. - *

- *

- * This Hashmap does not allow nulls to be used as keys or values. - *

- *

+ * A space-efficient open-addressing HashMap implementation. + * + * Unlike {@link OpenHashTable}, it maintains single arrays for keys and object references. + * * It uses single open hashing arrays sized to binary powers (256, 512 etc) rather than those - divisable by prime numbers. This allows the hash offset calculation to be a simple binary masking + divisible by prime numbers. This allows the hash offset calculation to be a simple binary masking * operation. + * + * The index into the arrays is determined by masking a portion of the key and shifting it to + * provide a series of small buckets within the array. To insert an entry, a sweep is searched + * until an empty key space is found. A sweep is 4 times the length of a bucket, to reduce the need + * to rehash. If no key space is found within a sweep, the table size is doubled. + * + * While performance is high, the slowest situation is where lookup occurs for entries that do not + * exist, as an entire sweep area must be searched. However, this HashMap is more space efficient + * than other open-addressing HashMap implementations such as those in fastutil. + * + * Note that this HashMap does not allow nulls to be used as keys. */ public final class OpenHashMap implements Map, Externalizable { private K[] keys; @@ -80,21 +93,21 @@ public OpenHashMap(int size) { resize(MathUtils.bitsRequired(size < 256 ? 256 : size)); } - public V put(K key, V value) { + @Nullable + public V put(@CheckForNull final K key, @Nullable final V value) { if (key == null) { throw new NullPointerException(this.getClass().getName() + " key"); } for (;;) { int off = getBucketOffset(key); - int end = off + sweep; + final int end = off + sweep; for (; off < end; off++) { - K searchKey = keys[off]; + final K searchKey = keys[off]; if (searchKey == null) { // insert keys[off] = key; size++; - V previous = values[off]; values[off] = value; return previous; @@ -109,9 +122,36 @@ public V put(K key, V value) { } } - public V get(Object key) { + @Nullable + public V putIfAbsent(@CheckForNull final K key, @Nullable final V value) { + if (key == null) { + throw new NullPointerException(this.getClass().getName() + " key"); + } + + for (;;) { + int off = getBucketOffset(key); + final int end = off + sweep; + for (; off < end; off++) { + final K searchKey = keys[off]; + if (searchKey == null) { + // insert + keys[off] = key; + size++; + V previous = values[off]; + values[off] = value; + return previous; + } else if (compare(searchKey, key)) { + return values[off]; + } + } + resize(this.bits + 1); + } + } + + @Nullable + public V get(@Nonnull final Object key) { int off = getBucketOffset(key); - int end = sweep + off; + final int end = sweep + off; for (; off < end; off++) { if (keys[off] != null && compare(keys[off], key)) { return values[off]; @@ -120,9 +160,10 @@ public V get(Object key) { return null; } - public V remove(Object key) { + @Nullable + public V remove(@Nonnull final Object key) { int off = getBucketOffset(key); - int end = sweep + off; + final int end = sweep + off; for (; off < end; off++) { if (keys[off] != null && compare(keys[off], key)) { keys[off] = null; @@ -139,7 +180,7 @@ public int size() { return size; } - public void putAll(Map m) { + public void putAll(@Nonnull final Map m) { for (K key : m.keySet()) { put(key, m.get(key)); } @@ -149,11 +190,11 @@ public boolean isEmpty() { return size == 0; } - public boolean containsKey(Object key) { + public boolean containsKey(@Nonnull final Object key) { return get(key) != null; } - public boolean 
containsValue(Object value) { + public boolean containsValue(@Nonnull final Object value) { for (V v : values) { if (v != null && compare(v, value)) { return true; @@ -165,11 +206,12 @@ public boolean containsValue(Object value) { public void clear() { Arrays.fill(keys, null); Arrays.fill(values, null); - size = 0; + this.size = 0; } + @Nonnull public Set keySet() { - Set set = new HashSet(); + final Set set = new HashSet(); for (K key : keys) { if (key != null) { set.add(key); @@ -178,8 +220,9 @@ public Set keySet() { return set; } + @Nonnull public Collection values() { - Collection list = new ArrayList(); + final Collection list = new ArrayList(); for (V value : values) { if (value != null) { list.add(value); @@ -188,8 +231,9 @@ public Collection values() { return list; } + @Nonnull public Set> entrySet() { - Set> set = new HashSet>(); + final Set> set = new HashSet>(); for (K key : keys) { if (key != null) { set.add(new MapEntry(this, key)); @@ -207,19 +251,23 @@ public MapEntry(Map map, K key) { this.key = key; } + @Override public K getKey() { return key; } + @Override public V getValue() { return map.get(key); } + @Override public V setValue(V value) { return map.put(key, value); } } + @Override public void writeExternal(ObjectOutput out) throws IOException { // remember the number of bits out.writeInt(this.bits); @@ -235,6 +283,7 @@ public void writeExternal(ObjectOutput out) throws IOException { } @SuppressWarnings("unchecked") + @Override public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { // resize to old bit size int bitSize = in.readInt(); @@ -250,19 +299,19 @@ public void readExternal(ObjectInput in) throws IOException, ClassNotFoundExcept @Override public String toString() { - return this.getClass().getSimpleName() + ' ' + this.size; + return this.getClass().getSimpleName() + ' ' + size; } @SuppressWarnings("unchecked") - private void resize(int bits) { + private void resize(final int bits) { this.bits = bits; this.sweepbits = bits / 4; this.sweep = MathUtils.powerOf(2, sweepbits) * 4; - this.sweepmask = MathUtils.bitMask(bits - this.sweepbits) << sweepbits; + this.sweepmask = MathUtils.bitMask(bits - sweepbits) << sweepbits; // remember old values so we can recreate the entries - K[] existingKeys = this.keys; - V[] existingValues = this.values; + final K[] existingKeys = this.keys; + final V[] existingValues = this.values; // create the arrays this.values = (V[]) new Object[MathUtils.powerOf(2, bits) + sweep]; @@ -272,31 +321,38 @@ private void resize(int bits) { // re-add the previous entries if resizing if (existingKeys != null) { for (int x = 0; x < existingKeys.length; x++) { - if (existingKeys[x] != null) { - put(existingKeys[x], existingValues[x]); + final K k = existingKeys[x]; + if (k != null) { + put(k, existingValues[x]); } } } } - private int getBucketOffset(Object key) { - return (key.hashCode() << this.sweepbits) & this.sweepmask; + private int getBucketOffset(@Nonnull final Object key) { + return (key.hashCode() << sweepbits) & sweepmask; } - private static boolean compare(final Object v1, final Object v2) { + private static boolean compare(@Nonnull final Object v1, @Nonnull final Object v2) { return v1 == v2 || v1.equals(v2); } public IMapIterator entries() { - return new MapIterator(); + return new MapIterator(false); + } + + public IMapIterator entries(boolean releaseSeen) { + return new MapIterator(releaseSeen); } private final class MapIterator implements IMapIterator { + final boolean releaseSeen; int nextEntry; int 
lastEntry = -1; - MapIterator() { + MapIterator(boolean releaseSeen) { + this.releaseSeen = releaseSeen; this.nextEntry = nextEntry(0); } @@ -315,7 +371,9 @@ public boolean hasNext() { @Override public int next() { - free(lastEntry); + if (releaseSeen) { + free(lastEntry); + } if (!hasNext()) { return -1; } diff --git a/core/src/main/java/hivemall/utils/collections/maps/OpenHashTable.java b/core/src/main/java/hivemall/utils/collections/maps/OpenHashTable.java index 7fec9b03a..4599bfc43 100644 --- a/core/src/main/java/hivemall/utils/collections/maps/OpenHashTable.java +++ b/core/src/main/java/hivemall/utils/collections/maps/OpenHashTable.java @@ -27,16 +27,22 @@ import java.io.ObjectInput; import java.io.ObjectOutput; import java.util.Arrays; -import java.util.HashMap; import javax.annotation.Nonnull; /** - * An open-addressing hash table with double-hashing that requires less memory to {@link HashMap}. + * An open-addressing hash table using double-hashing. + * + *

+ * Primary hash function: h1(k) = k mod m
+ * Secondary hash function: h2(k) = 1 + (k mod (m-2))
+ * 
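These tables also raise DEFAULT_LOAD_FACTOR from 0.7f to 0.75f in this patch. Combined with the preAddEntry/ensureCapacity logic shown in the IntOpenHashTable hunks earlier, the resize point is simple arithmetic; a hedged sketch with an invented capacity, not code from the patch:

```java
// Illustrative only: growth arithmetic under the new defaults.
public final class GrowthSketch {
    public static void main(String[] args) {
        final float loadFactor = 0.75f; // DEFAULT_LOAD_FACTOR, was 0.7f
        final float growFactor = 2.0f;  // DEFAULT_GROW_FACTOR
        final int capacity = 16411;     // hypothetical prime table length
        // preAddEntry: grow once (used + 1) >= threshold
        final int threshold = Math.round(capacity * loadFactor); // 12308
        // ensureCapacity: scale by the grow factor, then round up to the least prime
        final int grown = Math.round(capacity * growFactor);     // 32822 -> next prime
        System.out.println("rehash at " + threshold + " entries, to >= " + grown);
    }
}
```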
+ * + * @see http://en.wikipedia.org/wiki/Double_hashing */ public final class OpenHashTable implements Externalizable { - public static final float DEFAULT_LOAD_FACTOR = 0.7f; + public static final float DEFAULT_LOAD_FACTOR = 0.75f; public static final float DEFAULT_GROW_FACTOR = 2.0f; protected static final byte FREE = 0; diff --git a/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java b/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java index 0b68de8e3..db56b8208 100644 --- a/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java +++ b/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java @@ -289,12 +289,21 @@ public static boolean isIntegerOI(@Nonnull final ObjectInspector argOI) { } } - @Nonnull public static boolean isListOI(@Nonnull final ObjectInspector oi) { Category category = oi.getCategory(); return category == Category.LIST; } + public static boolean isStringListOI(@Nonnull final ObjectInspector oi) + throws UDFArgumentException { + Category category = oi.getCategory(); + if (category != Category.LIST) { + throw new UDFArgumentException("Expected List OI but was: " + oi); + } + ListObjectInspector listOI = (ListObjectInspector) oi; + return isStringOI(listOI.getListElementObjectInspector()); + } + public static boolean isMapOI(@Nonnull final ObjectInspector oi) { return oi.getCategory() == Category.MAP; } @@ -669,6 +678,36 @@ public static long[] asLongArray(@Nullable final Object argObj, return ary; } + @Nullable + public static float[] asFloatArray(@Nullable final Object argObj, + @Nonnull final ListObjectInspector listOI, + @Nonnull final PrimitiveObjectInspector elemOI) throws UDFArgumentException { + return asFloatArray(argObj, listOI, elemOI, true); + } + + @Nullable + public static float[] asFloatArray(@Nullable final Object argObj, + @Nonnull final ListObjectInspector listOI, + @Nonnull final PrimitiveObjectInspector elemOI, final boolean avoidNull) + throws UDFArgumentException { + if (argObj == null) { + return null; + } + final int length = listOI.getListLength(argObj); + final float[] ary = new float[length]; + for (int i = 0; i < length; i++) { + Object o = listOI.getListElement(argObj, i); + if (o == null) { + if (avoidNull) { + continue; + } + throw new UDFArgumentException("Found null at index " + i); + } + ary[i] = PrimitiveObjectInspectorUtils.getFloat(o, elemOI); + } + return ary; + } + @Nullable public static double[] asDoubleArray(@Nullable final Object argObj, @Nonnull final ListObjectInspector listOI, @@ -694,8 +733,7 @@ public static double[] asDoubleArray(@Nullable final Object argObj, } throw new UDFArgumentException("Found null at index " + i); } - double d = PrimitiveObjectInspectorUtils.getDouble(o, elemOI); - ary[i] = d; + ary[i] = PrimitiveObjectInspectorUtils.getDouble(o, elemOI); } return ary; } @@ -721,8 +759,7 @@ public static void toDoubleArray(@Nullable final Object argObj, } throw new UDFArgumentException("Found null at index " + i); } - double d = PrimitiveObjectInspectorUtils.getDouble(o, elemOI); - out[i] = d; + out[i] = PrimitiveObjectInspectorUtils.getDouble(o, elemOI); } return; } @@ -746,8 +783,7 @@ public static void toDoubleArray(@Nullable final Object argObj, out[i] = nullValue; continue; } - double d = PrimitiveObjectInspectorUtils.getDouble(o, elemOI); - out[i] = d; + out[i] = PrimitiveObjectInspectorUtils.getDouble(o, elemOI); } return; } @@ -766,11 +802,11 @@ public static int setBits(@Nullable final Object argObj, int count = 0; final int length = listOI.getListLength(argObj); for (int i = 0; i < length; i++) { - 
Object o = listOI.getListElement(argObj, i); + final Object o = listOI.getListElement(argObj, i); if (o == null) { continue; } - int index = PrimitiveObjectInspectorUtils.getInt(o, elemOI); + final int index = PrimitiveObjectInspectorUtils.getInt(o, elemOI); if (index < 0) { throw new UDFArgumentException("Negative index is not allowed: " + index); } @@ -954,6 +990,26 @@ public static PrimitiveObjectInspector asDoubleCompatibleOI(@Nonnull final Objec return oi; } + @Nonnull + public static PrimitiveObjectInspector asFloatingPointOI(@Nonnull final ObjectInspector argOI) + throws UDFArgumentTypeException { + if (argOI.getCategory() != Category.PRIMITIVE) { + throw new UDFArgumentTypeException(0, "Only primitive type arguments are accepted but " + + argOI.getTypeName() + " is passed."); + } + final PrimitiveObjectInspector oi = (PrimitiveObjectInspector) argOI; + switch (oi.getPrimitiveCategory()) { + case FLOAT: + case DOUBLE: + break; + default: + throw new UDFArgumentTypeException(0, + "Only float or double type arguments are accepted but " + argOI.getTypeName() + + " is passed."); + } + return oi; + } + @Nonnull public static PrimitiveObjectInspector asNumberOI(@Nonnull final ObjectInspector argOI) throws UDFArgumentTypeException { diff --git a/core/src/main/java/hivemall/utils/hashing/HashUtils.java b/core/src/main/java/hivemall/utils/hashing/HashUtils.java new file mode 100644 index 000000000..710d8f618 --- /dev/null +++ b/core/src/main/java/hivemall/utils/hashing/HashUtils.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
*/ +package hivemall.utils.hashing; + +public final class HashUtils { + + private HashUtils() {} + + public static int jenkins32(int k) { + k = (k + 0x7ed55d16) + (k << 12); + k = (k ^ 0xc761c23c) ^ (k >> 19); + k = (k + 0x165667b1) + (k << 5); + k = (k + 0xd3a2646c) ^ (k << 9); + k = (k + 0xfd7046c5) + (k << 3); + k = (k ^ 0xb55a4f09) ^ (k >> 16); + return k; + } + + public static int murmurHash3(int k) { + k ^= k >>> 16; + k *= 0x85ebca6b; + k ^= k >>> 13; + k *= 0xc2b2ae35; + k ^= k >>> 16; + return k; + } + + public static int fnv1a(final int k) { + int hash = 0x811c9dc5; + for (int i = 0; i < 4; i++) { + hash ^= k << (i * 8); + hash *= 0x01000193; + } + return hash; + } + + /** + * https://gist.github.com/badboy/6267743 + */ + public static int hash32shift(int k) { + k = ~k + (k << 15); // key = (key << 15) - key - 1; + k = k ^ (k >>> 12); + k = k + (k << 2); + k = k ^ (k >>> 4); + k = k * 2057; // key = (key + (key << 3)) + (key << 11); + k = k ^ (k >>> 16); + return k; + } + + public static int hash32shiftmult(int k) { + k = (k ^ 61) ^ (k >>> 16); + k = k + (k << 3); + k = k ^ (k >>> 4); + k = k * 0x27d4eb2d; + k = k ^ (k >>> 15); + return k; + } + + /** + * http://burtleburtle.net/bob/hash/integer.html + */ + public static int hash7shifts(int k) { + k -= (k << 6); + k ^= (k >> 17); + k -= (k << 9); + k ^= (k << 4); + k -= (k << 3); + k ^= (k << 10); + k ^= (k >> 15); + return k; + } + +} diff --git a/core/src/main/java/hivemall/utils/lang/NumberUtils.java b/core/src/main/java/hivemall/utils/lang/NumberUtils.java index 0d3f895e8..4b04f0443 100644 --- a/core/src/main/java/hivemall/utils/lang/NumberUtils.java +++ b/core/src/main/java/hivemall/utils/lang/NumberUtils.java @@ -107,4 +107,72 @@ public static boolean isDigits(String str) { return true; } + /** + * @throws ArithmeticException + */ + public static int castToInt(final long value) { + final int result = (int) value; + if (result != value) { + throw new ArithmeticException("Out of range: " + value); + } + return result; + } + + /** + * @throws ArithmeticException + */ + public static short castToShort(final int value) { + final short result = (short) value; + if (result != value) { + throw new ArithmeticException("Out of range: " + value); + } + return result; + } + + /** + * Cast Double to Float. + * + * @throws ArithmeticException + */ + public static float castToFloat(final double v) { + if ((v < -Float.MAX_VALUE) || (v > Float.MAX_VALUE)) { + throw new ArithmeticException("Double value is out of Float range: " + v); + } + return (float) v; + } + + /** + * Cast Double to Float. + * + * @return v if v is in Float range; -Float.MAX_VALUE or Float.MAX_VALUE otherwise + */ + public static float safeCast(final double v) { + if (v < -Float.MAX_VALUE) { + return -Float.MAX_VALUE; + } else if (v > Float.MAX_VALUE) { + return Float.MAX_VALUE; + } + return (float) v; + } + + /** + * Cast Double to Float. + * + * @return v if v is in Float range; defaultValue otherwise + */ + public static float safeCast(final double v, final float defaultValue) { + if ((v < -Float.MAX_VALUE) || (v > Float.MAX_VALUE)) { + return defaultValue; + } + return (float) v; + } + + public static int toUnsignedShort(final short v) { + return v & 0xFFFF; // convert to range 0-65535 from -32768-32767. 
+ } + + public static int toUnsignedInt(final byte x) { + return ((int) x) & 0xff; + } + } diff --git a/core/src/main/java/hivemall/utils/lang/Primitives.java b/core/src/main/java/hivemall/utils/lang/Primitives.java index 2ec012cb2..7d43da110 100644 --- a/core/src/main/java/hivemall/utils/lang/Primitives.java +++ b/core/src/main/java/hivemall/utils/lang/Primitives.java @@ -26,14 +26,6 @@ public final class Primitives { private Primitives() {} - public static int toUnsignedShort(final short v) { - return v & 0xFFFF; // convert to range 0-65535 from -32768-32767. - } - - public static int toUnsignedInt(final byte x) { - return ((int) x) & 0xff; - } - public static short parseShort(final String s, final short defaultValue) { if (s == null) { return defaultValue; @@ -92,22 +84,6 @@ public static void putChar(final byte[] b, final int off, final char val) { b[off] = (byte) (val >>> 8); } - public static int toIntExact(final long longValue) { - final int casted = (int) longValue; - if (casted != longValue) { - throw new ArithmeticException("integer overflow: " + longValue); - } - return casted; - } - - public static int castToInt(final long value) { - final int result = (int) value; - if (result != value) { - throw new IllegalArgumentException("Out of range: " + value); - } - return result; - } - public static long toLong(final int high, final int low) { return ((long) high << 32) | ((long) low & 0xffffffffL); } diff --git a/core/src/main/java/hivemall/utils/math/MathUtils.java b/core/src/main/java/hivemall/utils/math/MathUtils.java index 3f41b6fc6..6162adb10 100644 --- a/core/src/main/java/hivemall/utils/math/MathUtils.java +++ b/core/src/main/java/hivemall/utils/math/MathUtils.java @@ -264,7 +264,7 @@ public static long floorDiv(final long x, final long y) { return r; } - public static boolean equals(@Nonnull final float value, final float expected, final float delta) { + public static boolean equals(final float value, final float expected, final float delta) { if (Double.isNaN(value)) { return false; } @@ -274,8 +274,7 @@ public static boolean equals(@Nonnull final float value, final float expected, f return true; } - public static boolean equals(@Nonnull final double value, final double expected, - final double delta) { + public static boolean equals(final double value, final double expected, final double delta) { if (Double.isNaN(value)) { return false; } @@ -285,26 +284,34 @@ public static boolean equals(@Nonnull final double value, final double expected, return true; } - public static boolean almostEquals(@Nonnull final float value, final float expected) { + public static boolean almostEquals(final float value, final float expected) { return equals(value, expected, 1E-15f); } - public static boolean almostEquals(@Nonnull final double value, final double expected) { + public static boolean almostEquals(final double value, final double expected) { return equals(value, expected, 1E-15d); } - public static boolean closeToZero(@Nonnull final float value) { - if (Math.abs(value) > 1E-15f) { - return false; + public static boolean closeToZero(final float value) { + return closeToZero(value, 1E-15f); + } + + public static boolean closeToZero(final float value, @Nonnegative final float tol) { + if (value == 0.f) { + return true; } - return true; + return Math.abs(value) <= tol; } - public static boolean closeToZero(@Nonnull final double value) { - if (Math.abs(value) > 1E-15d) { - return false; + public static boolean closeToZero(final double value) { + return closeToZero(value, 1E-15d); + } + + 
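The closeToZero rework in this hunk replaces the hard-coded epsilon with an explicit tolerance overload; a brief usage sketch of the behavior the new code implies (values illustrative):

```java
// hivemall.utils.math.MathUtils after this patch
MathUtils.closeToZero(1e-16f);         // true: |v| <= 1E-15f, the default tolerance
MathUtils.closeToZero(1e-9f);          // false under the default tolerance
MathUtils.closeToZero(1e-9f, 1e-6f);   // true with an explicit, looser tolerance
```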
public static boolean closeToZero(final double value, @Nonnegative final double tol) { + if (value == 0.d) { + return true; } - return true; + return Math.abs(value) <= tol; } public static double sign(final double x) { diff --git a/core/src/test/java/hivemall/fm/FFMPredictionModelTest.java b/core/src/test/java/hivemall/fm/FFMPredictionModelTest.java deleted file mode 100644 index 076387f7b..000000000 --- a/core/src/test/java/hivemall/fm/FFMPredictionModelTest.java +++ /dev/null @@ -1,65 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package hivemall.fm; - -import hivemall.utils.buffer.HeapBuffer; -import hivemall.utils.collections.maps.Int2LongOpenHashTable; - -import java.io.IOException; - -import org.junit.Assert; -import org.junit.Test; - -public class FFMPredictionModelTest { - - @Test - public void testSerialize() throws IOException, ClassNotFoundException { - final int factors = 3; - final int entrySize = Entry.sizeOf(factors); - - HeapBuffer buf = new HeapBuffer(HeapBuffer.DEFAULT_CHUNK_SIZE); - Int2LongOpenHashTable map = Int2LongOpenHashTable.newInstance(); - - Entry e1 = new Entry(buf, factors, buf.allocate(entrySize)); - e1.setW(1f); - e1.setV(new float[] {1f, -1f, -1f}); - - Entry e2 = new Entry(buf, factors, buf.allocate(entrySize)); - e2.setW(2f); - e2.setV(new float[] {1f, 2f, -1f}); - - Entry e3 = new Entry(buf, factors, buf.allocate(entrySize)); - e3.setW(3f); - e3.setV(new float[] {1f, 2f, 3f}); - - map.put(1, e1.getOffset()); - map.put(2, e2.getOffset()); - map.put(3, e3.getOffset()); - - FFMPredictionModel expected = new FFMPredictionModel(map, buf, 0.d, 3, - Feature.DEFAULT_NUM_FEATURES, Feature.DEFAULT_NUM_FIELDS); - byte[] b = expected.serialize(); - - FFMPredictionModel actual = FFMPredictionModel.deserialize(b, b.length); - Assert.assertEquals(3, actual.getNumFactors()); - Assert.assertEquals(Feature.DEFAULT_NUM_FEATURES, actual.getNumFeatures()); - Assert.assertEquals(Feature.DEFAULT_NUM_FIELDS, actual.getNumFields()); - } - -} diff --git a/core/src/test/java/hivemall/fm/FeatureTest.java b/core/src/test/java/hivemall/fm/FeatureTest.java index 25e56716c..911a4a589 100644 --- a/core/src/test/java/hivemall/fm/FeatureTest.java +++ b/core/src/test/java/hivemall/fm/FeatureTest.java @@ -34,7 +34,7 @@ public void testParseFeature() throws HiveException { @Test public void testParseFFMFeature() throws HiveException { - IntFeature f1 = Feature.parseFFMFeature("2:1163:0.3651"); + IntFeature f1 = Feature.parseFFMFeature("2:1163:0.3651", -1); Assert.assertEquals(2, f1.getField()); Assert.assertEquals(1163, f1.getFeatureIndex()); Assert.assertEquals("1163", f1.getFeature()); @@ -85,4 +85,9 @@ public void testParseIntFeatureFails() throws HiveException { Feature.parseFeature("2:1163:0.3651", true); } + 
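The FeatureTest changes around here pin down the FFM input layout: parseFFMFeature("2:1163:0.3651", -1) above decodes to field 2, index 1163, value 0.3651, and the new test that follows feeds "0:0.3652" and expects a HiveException, apparently because a zero index is rejected. A hand-rolled illustration of the field:index:value layout, not the library's parser:

```java
// Illustrative only; Feature.parseFFMFeature additionally validates the parts
// (e.g. a zero index is rejected, per the test below).
final String s = "2:1163:0.3651";
final String[] parts = s.split(":");
final int field = Integer.parseInt(parts[0]);        // 2
final int index = Integer.parseInt(parts[1]);        // 1163
final double value = Double.parseDouble(parts[2]);   // 0.3651
```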
@Test(expected = HiveException.class) + public void testParseFeatureZeroIndex() throws HiveException { + Feature.parseFFMFeature("0:0.3652"); + } + } diff --git a/core/src/test/java/hivemall/fm/FieldAwareFactorizationMachineUDTFTest.java b/core/src/test/java/hivemall/fm/FieldAwareFactorizationMachineUDTFTest.java index 792ede113..3b219c699 100644 --- a/core/src/test/java/hivemall/fm/FieldAwareFactorizationMachineUDTFTest.java +++ b/core/src/test/java/hivemall/fm/FieldAwareFactorizationMachineUDTFTest.java @@ -23,11 +23,11 @@ import java.io.InputStream; import java.io.InputStreamReader; import java.util.ArrayList; +import java.util.List; import java.util.zip.GZIPInputStream; import javax.annotation.Nonnull; -import org.apache.commons.lang.StringUtils; import org.apache.hadoop.hive.ql.metadata.HiveException; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; @@ -44,32 +44,29 @@ public class FieldAwareFactorizationMachineUDTFTest { @Test public void testSGD() throws HiveException, IOException { - runTest("Pure SGD test", - "-classification -factors 10 -w0 -seed 43 -disable_adagrad -disable_ftrl", 0.60f); + runTest("Pure SGD test", "-opt sgd -classification -factors 10 -w0 -seed 43", 0.60f); } @Test - public void testSGDWithFTRL() throws HiveException, IOException { - runTest("SGD w/ FTRL test", "-classification -factors 10 -w0 -seed 43 -disable_adagrad", - 0.60f); + public void testAdaGrad() throws HiveException, IOException { + runTest("AdaGrad test", "-opt adagrad -classification -factors 10 -w0 -seed 43", 0.30f); } @Test public void testAdaGradNoCoeff() throws HiveException, IOException { - runTest("AdaGrad No Coeff test", "-classification -factors 10 -w0 -seed 43 -no_coeff", - 0.30f); + runTest("AdaGrad No Coeff test", + "-opt adagrad -no_coeff -classification -factors 10 -w0 -seed 43", 0.30f); } @Test - public void testAdaGradNoFTRL() throws HiveException, IOException { - runTest("AdaGrad w/o FTRL test", "-classification -factors 10 -w0 -seed 43 -disable_ftrl", - 0.30f); + public void testFTRL() throws HiveException, IOException { + runTest("FTRL test", "-opt ftrl -classification -factors 10 -w0 -seed 43", 0.30f); } @Test - public void testAdaGradDefault() throws HiveException, IOException { - runTest("AdaGrad DEFAULT (adagrad for V + FTRL for W)", - "-classification -factors 10 -w0 -seed 43", 0.30f); + public void testFTRLNoCoeff() throws HiveException, IOException { + runTest("FTRL No Coeff test", "-opt ftrl -no_coeff -classification -factors 10 -w0 -seed 43", + 0.30f); } private static void runTest(String testName, String testOptions, float lossThreshold) @@ -100,30 +97,22 @@ private static void runTest(String testName, String testOptions, float lossThres if (input == null) { break; } - ArrayList featureStrings = new ArrayList(); - ArrayList features = new ArrayList(); - - //make StringFeature for each word = data point - String remaining = input; - int wordCut = remaining.indexOf(' '); - while (wordCut != -1) { - featureStrings.add(remaining.substring(0, wordCut)); - remaining = remaining.substring(wordCut + 1); - wordCut = remaining.indexOf(' '); - } - int end = featureStrings.size(); - double y = Double.parseDouble(featureStrings.get(0)); + String[] featureStrings = input.split(" "); + + double y = Double.parseDouble(featureStrings[0]); if (y == 0) { y = -1;//LibFFM data uses {0, 1}; Hivemall uses {-1, 1} } - for (int wordNumber = 1; wordNumber < end; ++wordNumber) { - String entireFeature 
= featureStrings.get(wordNumber); - int featureCut = StringUtils.ordinalIndexOf(entireFeature, ":", 2); - String feature = entireFeature.substring(0, featureCut); - double value = Double.parseDouble(entireFeature.substring(featureCut + 1)); - features.add(new StringFeature(feature, value)); + + final List features = new ArrayList(featureStrings.length - 1); + for (int j = 1; j < featureStrings.length; ++j) { + String[] splitted = featureStrings[j].split(":"); + Assert.assertEquals(3, splitted.length); + int index = Integer.parseInt(splitted[1]) + 1; + String f = splitted[0] + ':' + index + ':' + splitted[2]; + features.add(f); } - udtf.process(new Object[] {toStringArray(features), y}); + udtf.process(new Object[] {features, y}); } cumul = udtf._cvState.getCumulativeLoss(); loss = (cumul - loss) / lines; @@ -143,15 +132,6 @@ private static BufferedReader readFile(@Nonnull String fileName) throws IOExcept return new BufferedReader(new InputStreamReader(is)); } - private static String[] toStringArray(ArrayList x) { - final int size = x.size(); - final String[] ret = new String[size]; - for (int i = 0; i < size; i++) { - ret[i] = x.get(i).toString(); - } - return ret; - } - private static void println(String line) { if (DEBUG) { System.out.println(line); diff --git a/core/src/test/java/hivemall/smile/tools/TreePredictUDFv1Test.java b/core/src/test/java/hivemall/smile/tools/TreePredictUDFv1Test.java index bf2ac1184..f88504191 100644 --- a/core/src/test/java/hivemall/smile/tools/TreePredictUDFv1Test.java +++ b/core/src/test/java/hivemall/smile/tools/TreePredictUDFv1Test.java @@ -52,6 +52,7 @@ import smile.validation.LOOCV; import smile.validation.RMSE; +@SuppressWarnings("deprecation") public class TreePredictUDFv1Test { private static final boolean DEBUG = false; diff --git a/core/src/test/java/hivemall/utils/collections/maps/Int2FloatOpenHashMapTest.java b/core/src/test/java/hivemall/utils/collections/maps/Int2FloatOpenHashTableTest.java similarity index 98% rename from core/src/test/java/hivemall/utils/collections/maps/Int2FloatOpenHashMapTest.java rename to core/src/test/java/hivemall/utils/collections/maps/Int2FloatOpenHashTableTest.java index 6a2ff96c6..53814acda 100644 --- a/core/src/test/java/hivemall/utils/collections/maps/Int2FloatOpenHashMapTest.java +++ b/core/src/test/java/hivemall/utils/collections/maps/Int2FloatOpenHashTableTest.java @@ -23,7 +23,7 @@ import org.junit.Assert; import org.junit.Test; -public class Int2FloatOpenHashMapTest { +public class Int2FloatOpenHashTableTest { @Test public void testSize() { diff --git a/core/src/test/java/hivemall/utils/collections/maps/Int2LongOpenHashMapTest.java b/core/src/test/java/hivemall/utils/collections/maps/Int2LongOpenHashMapTest.java index 7951b0b3b..ee36a83e0 100644 --- a/core/src/test/java/hivemall/utils/collections/maps/Int2LongOpenHashMapTest.java +++ b/core/src/test/java/hivemall/utils/collections/maps/Int2LongOpenHashMapTest.java @@ -18,11 +18,6 @@ */ package hivemall.utils.collections.maps; -import hivemall.utils.collections.maps.Int2LongOpenHashTable; -import hivemall.utils.lang.ObjectUtils; - -import java.io.IOException; - import org.junit.Assert; import org.junit.Test; @@ -30,7 +25,7 @@ public class Int2LongOpenHashMapTest { @Test public void testSize() { - Int2LongOpenHashTable map = new Int2LongOpenHashTable(16384); + Int2LongOpenHashMap map = new Int2LongOpenHashMap(16384); map.put(1, 3L); Assert.assertEquals(3L, map.get(1)); map.put(1, 5L); @@ -40,67 +35,72 @@ public void testSize() { @Test public void 
testDefaultReturnValue() { - Int2LongOpenHashTable map = new Int2LongOpenHashTable(16384); + Int2LongOpenHashMap map = new Int2LongOpenHashMap(16384); Assert.assertEquals(0, map.size()); - Assert.assertEquals(-1L, map.get(1)); - long ret = Long.MIN_VALUE; - map.defaultReturnValue(ret); - Assert.assertEquals(ret, map.get(1)); + Assert.assertEquals(0L, map.get(1)); + Assert.assertEquals(Long.MIN_VALUE, map.get(1, Long.MIN_VALUE)); } @Test public void testPutAndGet() { - Int2LongOpenHashTable map = new Int2LongOpenHashTable(16384); + Int2LongOpenHashMap map = new Int2LongOpenHashMap(16384); final int numEntries = 1000000; for (int i = 0; i < numEntries; i++) { - Assert.assertEquals(-1L, map.put(i, i)); + Assert.assertEquals(0L, map.put(i, i)); + Assert.assertEquals(0L, map.put(-i, -i)); } - Assert.assertEquals(numEntries, map.size()); + Assert.assertEquals(numEntries * 2 - 1, map.size()); for (int i = 0; i < numEntries; i++) { - long v = map.get(i); - Assert.assertEquals(i, v); + Assert.assertEquals(i, map.get(i)); + Assert.assertEquals(-i, map.get(-i)); } } @Test - public void testSerde() throws IOException, ClassNotFoundException { - Int2LongOpenHashTable map = new Int2LongOpenHashTable(16384); + public void testPutRemoveGet() { + Int2LongOpenHashMap map = new Int2LongOpenHashMap(16384); final int numEntries = 1000000; for (int i = 0; i < numEntries; i++) { - Assert.assertEquals(-1L, map.put(i, i)); + Assert.assertEquals(0L, map.put(i, i)); + Assert.assertEquals(0L, map.put(-i, -i)); + if (i % 2 == 0) { + Assert.assertEquals(i, map.remove(i, -1)); + } else { + Assert.assertEquals(i, map.put(i, i)); + } } - - byte[] b = ObjectUtils.toCompressedBytes(map); - map = new Int2LongOpenHashTable(16384); - ObjectUtils.readCompressedObject(b, map); - - Assert.assertEquals(numEntries, map.size()); + Assert.assertEquals(numEntries + (numEntries / 2) - 1, map.size()); for (int i = 0; i < numEntries; i++) { - long v = map.get(i); - Assert.assertEquals(i, v); + if (i % 2 == 0) { + Assert.assertFalse(map.containsKey(i)); + } else { + Assert.assertEquals(i, map.get(i)); + } + Assert.assertEquals(-i, map.get(-i)); } } @Test public void testIterator() { - Int2LongOpenHashTable map = new Int2LongOpenHashTable(1000); - Int2LongOpenHashTable.IMapIterator itor = map.entries(); + Int2LongOpenHashMap map = new Int2LongOpenHashMap(1000); + Int2LongOpenHashMap.MapIterator itor = map.entries(); Assert.assertFalse(itor.hasNext()); final int numEntries = 1000000; for (int i = 0; i < numEntries; i++) { - Assert.assertEquals(-1L, map.put(i, i)); + Assert.assertEquals(0L, map.put(i, i)); + Assert.assertEquals(0L, map.put(-i, -i)); } - Assert.assertEquals(numEntries, map.size()); + Assert.assertEquals(numEntries * 2 - 1, map.size()); itor = map.entries(); Assert.assertTrue(itor.hasNext()); while (itor.hasNext()) { - Assert.assertFalse(itor.next() == -1); + Assert.assertTrue(itor.next()); int k = itor.getKey(); long v = itor.getValue(); Assert.assertEquals(k, v); } - Assert.assertEquals(-1, itor.next()); + Assert.assertFalse(itor.next()); } } diff --git a/core/src/test/java/hivemall/utils/collections/maps/Int2LongOpenHashTableTest.java b/core/src/test/java/hivemall/utils/collections/maps/Int2LongOpenHashTableTest.java new file mode 100644 index 000000000..c2ce132bb --- /dev/null +++ b/core/src/test/java/hivemall/utils/collections/maps/Int2LongOpenHashTableTest.java @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.utils.collections.maps; + +import hivemall.utils.lang.ObjectUtils; + +import java.io.IOException; + +import org.junit.Assert; +import org.junit.Test; + +public class Int2LongOpenHashTableTest { + + @Test + public void testSize() { + Int2LongOpenHashTable map = new Int2LongOpenHashTable(16384); + map.put(1, 3L); + Assert.assertEquals(3L, map.get(1)); + map.put(1, 5L); + Assert.assertEquals(5L, map.get(1)); + Assert.assertEquals(1, map.size()); + } + + @Test + public void testDefaultReturnValue() { + Int2LongOpenHashTable map = new Int2LongOpenHashTable(16384); + Assert.assertEquals(0, map.size()); + Assert.assertEquals(-1L, map.get(1)); + long ret = Long.MIN_VALUE; + map.defaultReturnValue(ret); + Assert.assertEquals(ret, map.get(1)); + } + + @Test + public void testPutAndGet() { + Int2LongOpenHashTable map = new Int2LongOpenHashTable(16384); + final int numEntries = 1000000; + for (int i = 0; i < numEntries; i++) { + Assert.assertEquals(-1L, map.put(i, i)); + } + Assert.assertEquals(numEntries, map.size()); + for (int i = 0; i < numEntries; i++) { + long v = map.get(i); + Assert.assertEquals(i, v); + } + } + + @Test + public void testPutRemoveGet() { + Int2LongOpenHashTable map = new Int2LongOpenHashTable(16384); + map.defaultReturnValue(0L); + final int numEntries = 1000000; + for (int i = 0; i < numEntries; i++) { + Assert.assertEquals(0L, map.put(i, i)); + Assert.assertEquals(0L, map.put(-i, -i)); + if (i % 2 == 0) { + Assert.assertEquals(i, map.remove(i)); + } else { + Assert.assertEquals(i, map.put(i, i)); + } + } + Assert.assertEquals(numEntries + (numEntries / 2) - 1, map.size()); + for (int i = 0; i < numEntries; i++) { + if (i % 2 == 0) { + Assert.assertFalse(map.containsKey(i)); + } else { + Assert.assertEquals(i, map.get(i)); + } + Assert.assertEquals(-i, map.get(-i)); + } + } + + @Test + public void testSerde() throws IOException, ClassNotFoundException { + Int2LongOpenHashTable map = new Int2LongOpenHashTable(16384); + final int numEntries = 1000000; + for (int i = 0; i < numEntries; i++) { + Assert.assertEquals(-1L, map.put(i, i)); + } + + byte[] b = ObjectUtils.toCompressedBytes(map); + map = new Int2LongOpenHashTable(16384); + ObjectUtils.readCompressedObject(b, map); + + Assert.assertEquals(numEntries, map.size()); + for (int i = 0; i < numEntries; i++) { + long v = map.get(i); + Assert.assertEquals(i, v); + } + } + + @Test + public void testIterator() { + Int2LongOpenHashTable map = new Int2LongOpenHashTable(1000); + Int2LongOpenHashTable.MapIterator itor = map.entries(); + Assert.assertFalse(itor.hasNext()); + + final int numEntries = 1000000; + for (int i = 0; i < numEntries; i++) { + Assert.assertEquals(-1L, map.put(i, i)); + } + Assert.assertEquals(numEntries, map.size()); + + itor = map.entries(); + Assert.assertTrue(itor.hasNext()); + while 
(itor.hasNext()) { + Assert.assertFalse(itor.next() == -1); + int k = itor.getKey(); + long v = itor.getValue(); + Assert.assertEquals(k, v); + } + Assert.assertEquals(-1, itor.next()); + } +} diff --git a/core/src/test/java/hivemall/utils/collections/maps/IntOpenHashMapTest.java b/core/src/test/java/hivemall/utils/collections/maps/IntOpenHashMapTest.java deleted file mode 100644 index 675c586b2..000000000 --- a/core/src/test/java/hivemall/utils/collections/maps/IntOpenHashMapTest.java +++ /dev/null @@ -1,75 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package hivemall.utils.collections.maps; - -import hivemall.utils.collections.maps.IntOpenHashMap; - -import org.junit.Assert; -import org.junit.Test; - -public class IntOpenHashMapTest { - - @Test - public void testSize() { - IntOpenHashMap map = new IntOpenHashMap(16384); - map.put(1, Float.valueOf(3.f)); - Assert.assertEquals(Float.valueOf(3.f), map.get(1)); - map.put(1, Float.valueOf(5.f)); - Assert.assertEquals(Float.valueOf(5.f), map.get(1)); - Assert.assertEquals(1, map.size()); - } - - @Test - public void testPutAndGet() { - IntOpenHashMap map = new IntOpenHashMap(16384); - final int numEntries = 1000000; - for (int i = 0; i < numEntries; i++) { - Assert.assertNull(map.put(i, i)); - } - Assert.assertEquals(numEntries, map.size()); - for (int i = 0; i < numEntries; i++) { - Integer v = map.get(i); - Assert.assertEquals(i, v.intValue()); - } - } - - @Test - public void testIterator() { - IntOpenHashMap map = new IntOpenHashMap(1000); - IntOpenHashMap.IMapIterator itor = map.entries(); - Assert.assertFalse(itor.hasNext()); - - final int numEntries = 1000000; - for (int i = 0; i < numEntries; i++) { - Assert.assertNull(map.put(i, i)); - } - Assert.assertEquals(numEntries, map.size()); - - itor = map.entries(); - Assert.assertTrue(itor.hasNext()); - while (itor.hasNext()) { - Assert.assertFalse(itor.next() == -1); - int k = itor.getKey(); - Integer v = itor.getValue(); - Assert.assertEquals(k, v.intValue()); - } - Assert.assertEquals(-1, itor.next()); - } - -} diff --git a/core/src/test/java/hivemall/utils/collections/maps/IntOpenHashTableTest.java b/core/src/test/java/hivemall/utils/collections/maps/IntOpenHashTableTest.java index d5887cd88..46a393816 100644 --- a/core/src/test/java/hivemall/utils/collections/maps/IntOpenHashTableTest.java +++ b/core/src/test/java/hivemall/utils/collections/maps/IntOpenHashTableTest.java @@ -49,4 +49,27 @@ public void testPutAndGet() { } } + @Test + public void testIterator() { + IntOpenHashTable map = new IntOpenHashTable(1000); + IntOpenHashTable.IMapIterator itor = map.entries(); + Assert.assertFalse(itor.hasNext()); + + final int numEntries = 1000000; + for (int i = 0; i < numEntries; i++) { + Assert.assertNull(map.put(i, i)); + } 
+ Assert.assertEquals(numEntries, map.size()); + + itor = map.entries(); + Assert.assertTrue(itor.hasNext()); + while (itor.hasNext()) { + Assert.assertFalse(itor.next() == -1); + int k = itor.getKey(); + Integer v = itor.getValue(); + Assert.assertEquals(k, v.intValue()); + } + Assert.assertEquals(-1, itor.next()); + } + } diff --git a/core/src/test/java/hivemall/utils/collections/maps/Long2IntOpenHashMapTest.java b/core/src/test/java/hivemall/utils/collections/maps/Long2IntOpenHashTableTest.java similarity index 98% rename from core/src/test/java/hivemall/utils/collections/maps/Long2IntOpenHashMapTest.java rename to core/src/test/java/hivemall/utils/collections/maps/Long2IntOpenHashTableTest.java index a03af538b..ca4338385 100644 --- a/core/src/test/java/hivemall/utils/collections/maps/Long2IntOpenHashMapTest.java +++ b/core/src/test/java/hivemall/utils/collections/maps/Long2IntOpenHashTableTest.java @@ -26,7 +26,7 @@ import org.junit.Assert; import org.junit.Test; -public class Long2IntOpenHashMapTest { +public class Long2IntOpenHashTableTest { @Test public void testSize() { diff --git a/docs/gitbook/getting_started/input-format.md b/docs/gitbook/getting_started/input-format.md index 7bd8573c6..a01b5e37b 100644 --- a/docs/gitbook/getting_started/input-format.md +++ b/docs/gitbook/getting_started/input-format.md @@ -190,25 +190,48 @@ from ## Quantitative Features -`array<string> quantitative_features(array<string> featureNames, ...)` is a helper function to create sparse quantitative features from a table. +`array<string> quantitative_features(array<string> featureNames, feature1, feature2, .. [, const string options])` is a helper function to create sparse quantitative features from a table. ```sql -select quantitative_features(array("apple","value"),1,120.3); +select quantitative_features( + array("apple","height","weight"), + 1,180.3,70.2 + -- ,"-emit_null" +); +``` +> ["apple:1.0","height:180.3","weight:70.2"] + +```sql +select quantitative_features( + array("apple","height","weight"), + 1,cast(null as double),70.2 + ,"-emit_null" +); ``` -> ["apple:1.0","value:120.3"] +> ["apple:1.0",null,"weight:70.2"] ## Categorical Features -`array<string> categorical_features(array<string> featureNames, ...)` is a helper function to create sparse categorical features from a table. +`array<string> categorical_features(array<string> featureNames, feature1, feature2, .. [, const string options])` is a helper function to create sparse categorical features from a table. 
```sql select categorical_features( array("is_cat","is_dog","is_lion","is_pengin","species"), 1, 0, 1.0, true, "dog" + -- ,"-emit_null" ); ``` > ["is_cat#1","is_dog#0","is_lion#1.0","is_pengin#true","species#dog"] +```sql +select categorical_features( + array("is_cat","is_dog","is_lion","is_pengin","species"), + 1, 0, 1.0, true, null + ,"-emit_null" +); +``` +> ["is_cat#1","is_dog#0","is_lion#1.0","is_pengin#true",null] + ## Preparing training data table You can create a training data table as follows: diff --git a/pom.xml b/pom.xml index 7d14c6412..49005ebcc 100644 --- a/pom.xml +++ b/pom.xml @@ -287,6 +287,24 @@ 2.0 + + java7 + + -ea -Xms768m -Xmx1024m -XX:PermSize=128m -XX:MaxPermSize=512m -XX:ReservedCodeCacheSize=512m + + + [,1.8) + + + + java8 + + -ea -Xms768m -Xmx1024m -XX:MetaspaceSize=128m -XX:MaxMetaspaceSize=512m -XX:ReservedCodeCacheSize=512m + + + [1.8,) + + compile-xgboost diff --git a/resources/ddl/define-all-as-permanent.hive b/resources/ddl/define-all-as-permanent.hive index feb1a08e4..c2b38fbf4 100644 --- a/resources/ddl/define-all-as-permanent.hive +++ b/resources/ddl/define-all-as-permanent.hive @@ -313,6 +313,9 @@ CREATE FUNCTION binarize_label as 'hivemall.ftvec.trans.BinarizeLabelUDTF' USING DROP FUNCTION IF EXISTS onehot_encoding; CREATE FUNCTION onehot_encoding as 'hivemall.ftvec.trans.OnehotEncodingUDAF' USING JAR '${hivemall_jar}'; +DROP FUNCTION IF EXISTS add_field_indicies; +CREATE FUNCTION add_field_indicies as 'hivemall.ftvec.trans.AddFieldIndicesUDF' USING JAR '${hivemall_jar}'; + ------------------------------ -- ranking helper functions -- ------------------------------ @@ -620,7 +623,7 @@ DROP FUNCTION IF EXISTS train_ffm; CREATE FUNCTION train_ffm as 'hivemall.fm.FieldAwareFactorizationMachineUDTF' USING JAR '${hivemall_jar}'; DROP FUNCTION IF EXISTS ffm_predict; -CREATE FUNCTION ffm_predict as 'hivemall.fm.FFMPredictUDF' USING JAR '${hivemall_jar}'; +CREATE FUNCTION ffm_predict as 'hivemall.fm.FFMPredictGenericUDAF' USING JAR '${hivemall_jar}'; --------------------------- -- Anomaly Detection ------ diff --git a/resources/ddl/define-all.hive b/resources/ddl/define-all.hive index 310f9f499..89821f802 100644 --- a/resources/ddl/define-all.hive +++ b/resources/ddl/define-all.hive @@ -309,6 +309,9 @@ create temporary function binarize_label as 'hivemall.ftvec.trans.BinarizeLabelU drop temporary function if exists onehot_encoding; create temporary function onehot_encoding as 'hivemall.ftvec.trans.OnehotEncodingUDAF'; +drop temporary function if exists add_field_indicies; +create temporary function add_field_indicies as 'hivemall.ftvec.trans.AddFieldIndicesUDF'; + ------------------------------ -- ranking helper functions -- ------------------------------ @@ -612,7 +615,7 @@ drop temporary function if exists train_ffm; create temporary function train_ffm as 'hivemall.fm.FieldAwareFactorizationMachineUDTF'; drop temporary function if exists ffm_predict; -create temporary function ffm_predict as 'hivemall.fm.FFMPredictUDF'; +create temporary function ffm_predict as 'hivemall.fm.FFMPredictGenericUDAF'; --------------------------- -- Anomaly Detection ------ diff --git a/resources/ddl/define-all.spark b/resources/ddl/define-all.spark index 42b235ba3..b4926e361 100644 --- a/resources/ddl/define-all.spark +++ b/resources/ddl/define-all.spark @@ -312,6 +312,9 @@ sqlContext.sql("CREATE TEMPORARY FUNCTION binarize_label AS 'hivemall.ftvec.tran sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS onehot_encoding") sqlContext.sql("CREATE TEMPORARY FUNCTION onehot_encoding 
AS 'hivemall.ftvec.trans.OnehotEncodingUDAF'") +sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS add_field_indicies") +sqlContext.sql("CREATE TEMPORARY FUNCTION add_field_indicies AS 'hivemall.ftvec.trans.AddFieldIndicesUDF'") + /** * ranking helper functions */ @@ -596,7 +599,7 @@ sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS train_ffm") sqlContext.sql("CREATE TEMPORARY FUNCTION train_ffm AS 'hivemall.fm.FieldAwareFactorizationMachineUDTF'") sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS ffm_predict") -sqlContext.sql("CREATE TEMPORARY FUNCTION ffm_predict AS 'hivemall.fm.FFMPredictUDF'") +sqlContext.sql("CREATE TEMPORARY FUNCTION ffm_predict AS 'hivemall.fm.FFMPredictGenericUDAF'") /** * Anomaly Detection diff --git a/resources/ddl/define-udfs.td.hql b/resources/ddl/define-udfs.td.hql index dd694e317..c7fdd4980 100644 --- a/resources/ddl/define-udfs.td.hql +++ b/resources/ddl/define-udfs.td.hql @@ -174,6 +174,9 @@ create temporary function dimsum_mapper as 'hivemall.knn.similarity.DIMSUMMapper create temporary function train_classifier as 'hivemall.classifier.GeneralClassifierUDTF'; create temporary function train_regressor as 'hivemall.regression.GeneralRegressorUDTF'; create temporary function tree_export as 'hivemall.smile.tools.TreeExportUDF'; +create temporary function train_ffm as 'hivemall.fm.FieldAwareFactorizationMachineUDTF'; +create temporary function ffm_predict as 'hivemall.fm.FFMPredictGenericUDAF'; +create temporary function add_field_indicies as 'hivemall.ftvec.trans.AddFieldIndicesUDF'; -- NLP features create temporary function tokenize_ja as 'hivemall.nlp.tokenizer.KuromojiUDF'; diff --git a/spark/spark-2.0/pom.xml b/spark/spark-2.0/pom.xml index 123c42486..74e934829 100644 --- a/spark/spark-2.0/pom.xml +++ b/spark/spark-2.0/pom.xml @@ -32,9 +32,6 @@ jar - 64m - 512m - 512m ${project.parent.basedir} @@ -164,11 +161,8 @@ - -Xms1024m + -Xms512m -Xmx1024m - -XX:PermSize=${PermGen} - -XX:MaxPermSize=${MaxPermGen} - -XX:ReservedCodeCacheSize=${CodeCacheSize} @@ -233,7 +227,7 @@ ${project.build.directory}/surefire-reports . SparkTestSuite.txt - -ea -Xmx2g -XX:MaxPermSize=${MaxPermGen} -XX:ReservedCodeCacheSize=${CodeCacheSize} + ${spark.test.jvm.opts} 1 diff --git a/spark/spark-2.1/pom.xml b/spark/spark-2.1/pom.xml index 22d3e1242..d7ab81ada 100644 --- a/spark/spark-2.1/pom.xml +++ b/spark/spark-2.1/pom.xml @@ -32,9 +32,6 @@ jar - 64m - 512m - 512m ${project.parent.basedir} @@ -164,11 +161,8 @@ - -Xms1024m + -Xms512m -Xmx1024m - -XX:PermSize=${PermGen} - -XX:MaxPermSize=${MaxPermGen} - -XX:ReservedCodeCacheSize=${CodeCacheSize} @@ -233,7 +227,7 @@ ${project.build.directory}/surefire-reports . SparkTestSuite.txt - -ea -Xmx2g -XX:MaxPermSize=${MaxPermGen} -XX:ReservedCodeCacheSize=${CodeCacheSize} + ${spark.test.jvm.opts} 1 diff --git a/spark/spark-common/pom.xml b/spark/spark-common/pom.xml index e8e8ff4e0..3153a7544 100644 --- a/spark/spark-common/pom.xml +++ b/spark/spark-common/pom.xml @@ -32,9 +32,6 @@ jar - 64m - 1024m - 512m ${project.parent.basedir} @@ -138,11 +135,8 @@ - -Xms1024m + -Xms512m -Xmx1024m - -XX:PermSize=${PermGen} - -XX:MaxPermSize=${MaxPermGen} - -XX:ReservedCodeCacheSize=${CodeCacheSize}
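Tying the DDL hunks above together: train_ffm, the UDAF-based ffm_predict, and add_field_indicies are now registered in every catalog. A hedged HiveQL sketch of how the training side might be invoked; the train table and its columns are invented, and the exact output of add_field_indicies (prepending 1-origin field indices, e.g. turning ["f1","f2"] into ["1:f1","2:f2"]) is inferred from the UDF's name rather than confirmed by this patch:

```sql
-- Hypothetical usage; table and column names are illustrative only.
SELECT
  train_ffm(add_field_indicies(features), label,
            '-classification -factors 10 -opt ftrl')
FROM
  train;
```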