From 5b1af8fea4659ec853b9eceebde90899e734071a Mon Sep 17 00:00:00 2001 From: Takuya Kitazawa Date: Thu, 6 Jul 2017 18:53:52 +0900 Subject: [PATCH 01/13] Close #87: [HIVEMALL-108] Support '-iter' option in generic predictors --- .../java/hivemall/GeneralLearnerBaseUDTF.java | 481 +++++++++++++++--- .../main/java/hivemall/UDTFWithOptions.java | 2 +- .../classifier/GeneralClassifierUDTF.java | 3 +- .../java/hivemall/common/ConversionState.java | 51 +- .../hivemall/fm/FactorizationMachineUDTF.java | 18 +- .../mf/BPRMatrixFactorizationUDTF.java | 20 +- .../mf/OnlineMatrixFactorizationUDTF.java | 16 +- .../java/hivemall/model/FeatureValue.java | 3 +- .../java/hivemall/optimizer/Optimizer.java | 2 +- .../regression/GeneralRegressionUDTF.java | 4 +- .../classifier/GeneralClassifierUDTFTest.java | 141 ++++- .../regression/GeneralRegressionUDTFTest.java | 149 +++++- 12 files changed, 712 insertions(+), 178 deletions(-) diff --git a/core/src/main/java/hivemall/GeneralLearnerBaseUDTF.java b/core/src/main/java/hivemall/GeneralLearnerBaseUDTF.java index 34c7ec970..f8d56be8f 100644 --- a/core/src/main/java/hivemall/GeneralLearnerBaseUDTF.java +++ b/core/src/main/java/hivemall/GeneralLearnerBaseUDTF.java @@ -19,6 +19,7 @@ package hivemall; import hivemall.annotations.VisibleForTesting; +import hivemall.common.ConversionState; import hivemall.model.FeatureValue; import hivemall.model.IWeightValue; import hivemall.model.PredictionModel; @@ -31,13 +32,23 @@ import hivemall.optimizer.OptimizerOptions; import hivemall.utils.collections.IMapIterator; import hivemall.utils.hadoop.HiveUtils; +import hivemall.utils.io.FileUtils; +import hivemall.utils.io.NIOUtils; +import hivemall.utils.io.NioStatefullSegment; import hivemall.utils.lang.FloatAccumulator; +import hivemall.utils.lang.NumberUtils; +import hivemall.utils.lang.Primitives; +import hivemall.utils.lang.SizeOf; +import java.io.File; +import java.io.IOException; +import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; +import javax.annotation.Nonnegative; import javax.annotation.Nonnull; import javax.annotation.Nullable; @@ -55,30 +66,64 @@ import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.JavaStringObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableStringObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.JavaIntObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableIntObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.JavaLongObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableLongObjectInspector; import org.apache.hadoop.io.FloatWritable; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.LongWritable; +import org.apache.hadoop.io.Text; +import org.apache.hadoop.mapred.Counters; +import org.apache.hadoop.mapred.Reporter; public abstract class GeneralLearnerBaseUDTF extends LearnerBaseUDTF { private static final Log logger = LogFactory.getLog(GeneralLearnerBaseUDTF.class); + public enum FeatureType { + JavaString, Text, JavaInteger, WritableInt, JavaLong, WritableLong + } + private ListObjectInspector featureListOI; private 
PrimitiveObjectInspector featureInputOI; private PrimitiveObjectInspector targetOI; private boolean parseFeature; + private FeatureType featureType; + + // ----------------------------------------- + // hyperparameters @Nonnull private final Map optimizerOptions; private Optimizer optimizer; private LossFunction lossFunction; + // ----------------------------------------- + private PredictionModel model; private long count; - // The accumulated delta of each weight values. + // ----------------------------------------- + // for mini-batch + + /** The accumulated delta of each weight value. */ @Nullable private transient Map accumulated; private int sampled; - private double cumLoss; + // ----------------------------------------- + + // for iterations + + @Nullable + protected transient NioStatefullSegment fileIO; + @Nullable + protected transient ByteBuffer inputBuf; + private int iterations; + protected ConversionState cvState; + + // ----------------------------------------- public GeneralLearnerBaseUDTF() { this(true); @@ -128,7 +173,6 @@ public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgu this.count = 0L; this.sampled = 0; - this.cumLoss = 0.d; return getReturnOI(featureOutputOI); } @@ -137,6 +181,12 @@ public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgu protected Options getOptions() { Options opts = super.getOptions(); opts.addOption("loss", "loss_function", true, getLossOptionDescription()); + opts.addOption("iter", "iterations", true, "The maximum number of iterations [default: 10]"); + // conversion check + opts.addOption("disable_cv", "disable_cvtest", false, + "Whether to disable convergence check [default: OFF]"); + opts.addOption("cv_rate", "convergence_rate", true, + "Threshold to determine convergence [default: 0.005]"); OptimizerOptions.setup(opts); return opts; } @@ -146,15 +196,33 @@ protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumen CommandLine cl = super.processOptions(argOIs); LossFunction lossFunction = LossFunctions.getLossFunction(getDefaultLossType()); - if (cl.hasOption("loss_function")) { - try { - lossFunction = LossFunctions.getLossFunction(cl.getOptionValue("loss_function")); - } catch (Throwable e) { - throw new UDFArgumentException(e.getMessage()); + int iterations = 10; + boolean conversionCheck = true; + double convergenceRate = 0.005d; + + if (cl != null) { + if (cl.hasOption("loss_function")) { + try { + lossFunction = LossFunctions.getLossFunction(cl.getOptionValue("loss_function")); + } catch (Throwable e) { + throw new UDFArgumentException(e.getMessage()); + } } + checkLossFunction(lossFunction); + + iterations = Primitives.parseInt(cl.getOptionValue("iterations"), iterations); + if (iterations < 1) { + throw new UDFArgumentException( + "'-iterations' must be greater than or equal to 1: " + iterations); + } + + conversionCheck = !cl.hasOption("disable_cvtest"); + convergenceRate = Primitives.parseDouble(cl.getOptionValue("cv_rate"), convergenceRate); } - checkLossFunction(lossFunction); + this.lossFunction = lossFunction; + this.iterations = iterations; + this.cvState = new ConversionState(conversionCheck, convergenceRate); OptimizerOptions.propcessOptions(cl, optimizerOptions); @@ -167,6 +235,23 @@ protected PrimitiveObjectInspector processFeaturesOI(@Nonnull ObjectInspector ar this.featureListOI = (ListObjectInspector) arg; ObjectInspector featureRawOI = featureListOI.getListElementObjectInspector(); HiveUtils.validateFeatureOI(featureRawOI); + if (featureRawOI 
instanceof JavaStringObjectInspector) { + this.featureType = FeatureType.JavaString; + } else if (featureRawOI instanceof WritableStringObjectInspector) { + this.featureType = FeatureType.Text; + } else if (featureRawOI instanceof JavaIntObjectInspector) { + this.featureType = FeatureType.JavaInteger; + } else if (featureRawOI instanceof WritableIntObjectInspector) { + this.featureType = FeatureType.WritableInt; + } else if (featureRawOI instanceof JavaLongObjectInspector) { + this.featureType = FeatureType.JavaLong; + } else if (featureRawOI instanceof WritableLongObjectInspector) { + this.featureType = FeatureType.WritableLong; + } else { + throw new UDFArgumentException("Feature object inspector must be one of " + + "[JavaString, WritableString, JavaInt, WritableInt, Long, WritableLong]: " + + featureRawOI.toString()); + } this.parseFeature = HiveUtils.isStringOI(featureRawOI); return HiveUtils.asPrimitiveObjectInspector(featureRawOI); } @@ -204,8 +289,99 @@ public void process(Object[] args) throws HiveException { checkTargetValue(target); count++; - train(featureVector, target); + + recordTrainSampleToTempFile(featureVector, target); + } + + protected void recordTrainSampleToTempFile(@Nonnull final FeatureValue[] featureVector, + final float target) throws HiveException { + if (iterations == 1) { + return; + } + + ByteBuffer buf = inputBuf; + NioStatefullSegment dst = fileIO; + + if (buf == null) { + final File file; + try { + file = File.createTempFile("hivemall_general_learner", ".sgmt"); + file.deleteOnExit(); + if (!file.canWrite()) { + throw new UDFArgumentException("Cannot write a temporary file: " + + file.getAbsolutePath()); + } + logger.info("Record training samples to a file: " + file.getAbsolutePath()); + } catch (IOException ioe) { + throw new UDFArgumentException(ioe); + } catch (Throwable e) { + throw new UDFArgumentException(e); + } + this.inputBuf = buf = ByteBuffer.allocateDirect(1024 * 1024); // 1 MB + this.fileIO = dst = new NioStatefullSegment(file, false); + } + + int featureVectorBytes = 0; + for (FeatureValue f : featureVector) { + if (f == null) { + continue; + } + String feature = f.getFeatureAsString(); + + // feature as String (even if it is Text or Integer) + featureVectorBytes += SizeOf.CHAR * feature.length(); + + // NIOUtils.putString() first puts the length of string before string itself + featureVectorBytes += SizeOf.INT; + + // value + featureVectorBytes += SizeOf.DOUBLE; + } + + // feature length, feature 1, feature 2, ..., feature n, target + int recordBytes = SizeOf.INT + featureVectorBytes + SizeOf.FLOAT; + int requiredBytes = SizeOf.INT + recordBytes; // need to allocate space for "recordBytes" itself + + int remain = buf.remaining(); + if (remain < requiredBytes) { + writeBuffer(buf, dst); + } + + buf.putInt(recordBytes); + buf.putInt(featureVector.length); + for (FeatureValue f : featureVector) { + writeFeatureValue(buf, f); + } + buf.putFloat(target); + } + + private void writeFeatureValue(@Nonnull ByteBuffer buf, @Nonnull FeatureValue f) { + NIOUtils.putString(f.getFeatureAsString(), buf); + buf.putDouble(f.getValue()); + } + + private FeatureValue readFeatureValue(@Nonnull ByteBuffer buf) { + Object feature = NIOUtils.getString(buf); + switch (featureType) { + case Text: + feature = new Text((String) feature); + break; + case JavaInteger: + feature = Integer.parseInt((String) feature); + break; + case WritableInt: + feature = new IntWritable(Integer.parseInt((String) feature)); + break; + case JavaLong: + feature = Long.parseLong((String) 
feature); + break; + case WritableLong: + feature = new LongWritable(Long.parseLong((String) feature)); + break; + } + double value = buf.getDouble(); + return new FeatureValue(feature, value); } @Nullable @@ -224,7 +400,11 @@ public final FeatureValue[] parseFeatures(@Nonnull final List features) { } final FeatureValue fv; if (parseFeature) { - fv = FeatureValue.parse(f); + if (featureType == FeatureType.JavaString) { + fv = FeatureValue.parseFeatureAsString((String) f); + } else { + fv = FeatureValue.parse(f); // = parse feature as Text + } } else { Object k = ObjectInspectorUtils.copyToStandardObject(f, featureInspector); fv = new FeatureValue(k, 1.f); @@ -234,6 +414,17 @@ public final FeatureValue[] parseFeatures(@Nonnull final List features) { return featureVector; } + private static void writeBuffer(@Nonnull ByteBuffer srcBuf, @Nonnull NioStatefullSegment dst) + throws HiveException { + srcBuf.flip(); + try { + dst.write(srcBuf); + } catch (IOException e) { + throw new HiveException("Exception caused while writing a buffer to file", e); + } + srcBuf.clear(); + } + public float predict(@Nonnull final FeatureValue[] features) { float score = 0.f; for (FeatureValue f : features) {// a += w[i] * x[i] @@ -253,8 +444,10 @@ public float predict(@Nonnull final FeatureValue[] features) { protected void update(@Nonnull final FeatureValue[] features, final float target, final float predicted) { - this.cumLoss += lossFunction.loss(predicted, target); // retain cumulative loss to check convergence - float dloss = lossFunction.dloss(predicted, target); + float loss = lossFunction.loss(predicted, target); + cvState.incrLoss(loss); // retain cumulative loss to check convergence + + final float dloss = lossFunction.dloss(predicted, target); if (is_mini_batch) { accumulateUpdate(features, dloss); @@ -318,66 +511,228 @@ protected void onlineUpdate(@Nonnull final FeatureValue[] features, final float @Override public final void close() throws HiveException { super.close(); - if (model != null) { - if (accumulated != null) { // Update model with accumulated delta - batchUpdate(); - this.accumulated = null; - } - int numForwarded = 0; - if (useCovariance()) { - final WeightValueWithCovar probe = new WeightValueWithCovar(); - final Object[] forwardMapObj = new Object[3]; - final FloatWritable fv = new FloatWritable(); - final FloatWritable cov = new FloatWritable(); - final IMapIterator itor = model.entries(); - while (itor.next() != -1) { - itor.getValue(probe); - if (!probe.isTouched()) { - continue; // skip outputting untouched weights + finalizeTraining(); + forwardModel(); + this.accumulated = null; + this.model = null; + } + + @VisibleForTesting + public void finalizeTraining() throws HiveException { + if (count == 0L) { + this.model = null; + return; + } + if (is_mini_batch) { // Update model with accumulated delta + batchUpdate(); + } + if (iterations > 1) { + runIterativeTraining(iterations); + } + } + + protected final void runIterativeTraining(@Nonnegative final int iterations) + throws HiveException { + final ByteBuffer buf = this.inputBuf; + final NioStatefullSegment dst = this.fileIO; + assert (buf != null); + assert (dst != null); + final long numTrainingExamples = count; + + final Reporter reporter = getReporter(); + final Counters.Counter iterCounter = (reporter == null) ? 
null : reporter.getCounter( + "hivemall.GeneralLearnerBase$Counter", "iteration"); + + try { + if (dst.getPosition() == 0L) {// run iterations w/o temporary file + if (buf.position() == 0) { + return; // no training example + } + buf.flip(); + + for (int iter = 2; iter <= iterations; iter++) { + cvState.next(); + reportProgress(reporter); + setCounterValue(iterCounter, iter); + + while (buf.remaining() > 0) { + int recordBytes = buf.getInt(); + assert (recordBytes > 0) : recordBytes; + int featureVectorLength = buf.getInt(); + final FeatureValue[] featureVector = new FeatureValue[featureVectorLength]; + for (int j = 0; j < featureVectorLength; j++) { + featureVector[j] = readFeatureValue(buf); + } + float target = buf.getFloat(); + train(featureVector, target); + } + buf.rewind(); + + if (is_mini_batch) { // Update model with accumulated delta + batchUpdate(); + } + + if (cvState.isConverged(numTrainingExamples)) { + break; } - Object k = itor.getKey(); - fv.set(probe.get()); - cov.set(probe.getCovariance()); - forwardMapObj[0] = k; - forwardMapObj[1] = fv; - forwardMapObj[2] = cov; - forward(forwardMapObj); - numForwarded++; } - } else { - final WeightValue probe = new WeightValue(); - final Object[] forwardMapObj = new Object[2]; - final FloatWritable fv = new FloatWritable(); - final IMapIterator itor = model.entries(); - while (itor.next() != -1) { - itor.getValue(probe); - if (!probe.isTouched()) { - continue; // skip outputting untouched weights + logger.info("Performed " + + cvState.getCurrentIteration() + + " iterations of " + + NumberUtils.formatNumber(numTrainingExamples) + + " training examples on memory (thus " + + NumberUtils.formatNumber(numTrainingExamples + * cvState.getCurrentIteration()) + " training updates in total) "); + } else {// read training examples in the temporary file and invoke train for each example + // write training examples in buffer to a temporary file + if (buf.remaining() > 0) { + writeBuffer(buf, dst); + } + try { + dst.flush(); + } catch (IOException e) { + throw new HiveException("Failed to flush a file: " + + dst.getFile().getAbsolutePath(), e); + } + if (logger.isInfoEnabled()) { + File tmpFile = dst.getFile(); + logger.info("Wrote " + numTrainingExamples + + " records to a temporary file for iterative training: " + + tmpFile.getAbsolutePath() + " (" + FileUtils.prettyFileSize(tmpFile) + + ")"); + } + + // run iterations + for (int iter = 2; iter <= iterations; iter++) { + cvState.next(); + setCounterValue(iterCounter, iter); + + buf.clear(); + dst.resetPosition(); + while (true) { + reportProgress(reporter); + // TODO prefetch + // writes training examples to a buffer in the temporary file + final int bytesRead; + try { + bytesRead = dst.read(buf); + } catch (IOException e) { + throw new HiveException("Failed to read a file: " + + dst.getFile().getAbsolutePath(), e); + } + if (bytesRead == 0) { // reached file EOF + break; + } + assert (bytesRead > 0) : bytesRead; + + // reads training examples from a buffer + buf.flip(); + int remain = buf.remaining(); + if (remain < SizeOf.INT) { + throw new HiveException("Illegal file format was detected"); + } + while (remain >= SizeOf.INT) { + int pos = buf.position(); + int recordBytes = buf.getInt(); + remain -= SizeOf.INT; + + if (remain < recordBytes) { + buf.position(pos); + break; + } + + int featureVectorLength = buf.getInt(); + final FeatureValue[] featureVector = new FeatureValue[featureVectorLength]; + for (int j = 0; j < featureVectorLength; j++) { + featureVector[j] = readFeatureValue(buf); + } + 
float target = buf.getFloat(); + train(featureVector, target); + + remain -= recordBytes; + } + buf.compact(); + } + + if (is_mini_batch) { // Update model with accumulated delta + batchUpdate(); + } + + if (cvState.isConverged(numTrainingExamples)) { + break; } - Object k = itor.getKey(); - fv.set(probe.get()); - forwardMapObj[0] = k; - forwardMapObj[1] = fv; - forward(forwardMapObj); - numForwarded++; } + logger.info("Performed " + + cvState.getCurrentIteration() + + " iterations of " + + NumberUtils.formatNumber(numTrainingExamples) + + " training examples on a secondary storage (thus " + + NumberUtils.formatNumber(numTrainingExamples + * cvState.getCurrentIteration()) + " training updates in total)"); } - long numMixed = model.getNumMixed(); - this.model = null; - logger.info("Trained a prediction model using " + count + " training examples" - + (numMixed > 0 ? "( numMixed: " + numMixed + " )" : "")); - logger.info("Forwarded the prediction model of " + numForwarded + " rows"); + } catch (Throwable e) { + throw new HiveException("Exception caused in the iterative training", e); + } finally { + // delete the temporary file and release resources + try { + dst.close(true); + } catch (IOException e) { + throw new HiveException("Failed to close a file: " + + dst.getFile().getAbsolutePath(), e); + } + this.inputBuf = null; + this.fileIO = null; } } - @VisibleForTesting - public double getCumulativeLoss() { - return cumLoss; + protected void forwardModel() throws HiveException { + int numForwarded = 0; + if (useCovariance()) { + final WeightValueWithCovar probe = new WeightValueWithCovar(); + final Object[] forwardMapObj = new Object[3]; + final FloatWritable fv = new FloatWritable(); + final FloatWritable cov = new FloatWritable(); + final IMapIterator itor = model.entries(); + while (itor.next() != -1) { + itor.getValue(probe); + if (!probe.isTouched()) { + continue; // skip outputting untouched weights + } + Object k = itor.getKey(); + fv.set(probe.get()); + cov.set(probe.getCovariance()); + forwardMapObj[0] = k; + forwardMapObj[1] = fv; + forwardMapObj[2] = cov; + forward(forwardMapObj); + numForwarded++; + } + } else { + final WeightValue probe = new WeightValue(); + final Object[] forwardMapObj = new Object[2]; + final FloatWritable fv = new FloatWritable(); + final IMapIterator itor = model.entries(); + while (itor.next() != -1) { + itor.getValue(probe); + if (!probe.isTouched()) { + continue; // skip outputting untouched weights + } + Object k = itor.getKey(); + fv.set(probe.get()); + forwardMapObj[0] = k; + forwardMapObj[1] = fv; + forward(forwardMapObj); + numForwarded++; + } + } + long numMixed = model.getNumMixed(); + logger.info("Trained a prediction model using " + count + " training examples" + + (numMixed > 0 ? "( numMixed: " + numMixed + " )" : "")); + logger.info("Forwarded the prediction model of " + numForwarded + " rows"); } @VisibleForTesting - public void resetCumulativeLoss() { - this.cumLoss = 0.d; + public double getCumulativeLoss() { + return (cvState == null) ? 
Double.NaN : cvState.getCumulativeLoss(); } - } diff --git a/core/src/main/java/hivemall/UDTFWithOptions.java b/core/src/main/java/hivemall/UDTFWithOptions.java index 39ab233cd..b09cffaff 100644 --- a/core/src/main/java/hivemall/UDTFWithOptions.java +++ b/core/src/main/java/hivemall/UDTFWithOptions.java @@ -63,7 +63,7 @@ protected final Reporter getReporter() { return mapredContext.getReporter(); } - protected static void reportProgress(@Nonnull Reporter reporter) { + protected static void reportProgress(@Nullable Reporter reporter) { if (reporter != null) { synchronized (reporter) { reporter.progress(); diff --git a/core/src/main/java/hivemall/classifier/GeneralClassifierUDTF.java b/core/src/main/java/hivemall/classifier/GeneralClassifierUDTF.java index 8e17de172..98cdf0bb5 100644 --- a/core/src/main/java/hivemall/classifier/GeneralClassifierUDTF.java +++ b/core/src/main/java/hivemall/classifier/GeneralClassifierUDTF.java @@ -42,7 +42,8 @@ public final class GeneralClassifierUDTF extends GeneralLearnerBaseUDTF { @Override protected String getLossOptionDescription() { return "Loss function [HingeLoss (default), LogLoss, SquaredHingeLoss, ModifiedHuberLoss, or\n" - + "a regression loss: SquaredLoss, QuantileLoss, EpsilonInsensitiveLoss, HuberLoss]"; + + "a regression loss: SquaredLoss, QuantileLoss, EpsilonInsensitiveLoss, " + + "SquaredEpsilonInsensitiveLoss, HuberLoss]"; } @Override diff --git a/core/src/main/java/hivemall/common/ConversionState.java b/core/src/main/java/hivemall/common/ConversionState.java index dd2066291..ff9224142 100644 --- a/core/src/main/java/hivemall/common/ConversionState.java +++ b/core/src/main/java/hivemall/common/ConversionState.java @@ -25,20 +25,19 @@ public final class ConversionState { private static final Log logger = LogFactory.getLog(ConversionState.class); /** Whether to check conversion */ - protected final boolean conversionCheck; + private final boolean conversionCheck; /** Threshold to determine convergence */ - protected final double convergenceRate; + private final double convergenceRate; /** being ready to end iteration */ - protected boolean readyToFinishIterations; + private boolean readyToFinishIterations; /** The cumulative errors in the training */ - protected double totalErrors; + private double totalErrors; /** The cumulative losses in an iteration */ - protected double currLosses, prevLosses; + private double currLosses, prevLosses; - protected int curIter; - protected float curEta; + private int curIter; public ConversionState() { this(true, 0.005d); @@ -51,8 +50,7 @@ public ConversionState(boolean conversionCheck, double convergenceRate) { this.totalErrors = 0.d; this.currLosses = 0.d; this.prevLosses = Double.POSITIVE_INFINITY; - this.curIter = 0; - this.curEta = Float.NaN; + this.curIter = 1; } public double getTotalErrors() { @@ -83,20 +81,16 @@ public boolean isLossIncreased() { return currLosses > prevLosses; } - public boolean isConverged(final int iter, final long obserbedTrainingExamples) { + public boolean isConverged(final long obserbedTrainingExamples) { if (conversionCheck == false) { - this.prevLosses = currLosses; - this.currLosses = 0.d; return false; } if (currLosses > prevLosses) { if (logger.isInfoEnabled()) { - logger.info("Iteration #" + iter + " currLoss `" + currLosses + "` > prevLosses `" - + prevLosses + '`'); + logger.info("Iteration #" + curIter + " currLoss `" + currLosses + + "` > prevLosses `" + prevLosses + '`'); } - this.prevLosses = currLosses; - this.currLosses = 0.d; this.readyToFinishIterations = 
false; return false; } @@ -105,7 +99,7 @@ public boolean isConverged(final int iter, final long obserbedTrainingExamples) if (changeRate < convergenceRate) { if (readyToFinishIterations) { // NOTE: never be true at the first iteration where prevLosses == Double.POSITIVE_INFINITY - logger.info("Training converged at " + iter + "-th iteration. [curLosses=" + logger.info("Training converged at " + curIter + "-th iteration. [curLosses=" + currLosses + ", prevLosses=" + prevLosses + ", changeRate=" + changeRate + ']'); return true; @@ -114,33 +108,24 @@ public boolean isConverged(final int iter, final long obserbedTrainingExamples) } } else { if (logger.isDebugEnabled()) { - logger.debug("Iteration #" + iter + " [curLosses=" + currLosses + ", prevLosses=" - + prevLosses + ", changeRate=" + changeRate + ", #trainingExamples=" - + obserbedTrainingExamples + ']'); + logger.debug("Iteration #" + curIter + " [curLosses=" + currLosses + + ", prevLosses=" + prevLosses + ", changeRate=" + changeRate + + ", #trainingExamples=" + obserbedTrainingExamples + ']'); } this.readyToFinishIterations = false; } - this.prevLosses = currLosses; - this.currLosses = 0.d; return false; } - public void logState(int iter, float eta) { - if (logger.isInfoEnabled()) { - logger.info("Iteration #" + iter + " [curLoss=" + currLosses + ", prevLoss=" - + prevLosses + ", eta=" + eta + ']'); - } - this.curIter = iter; - this.curEta = eta; + public void next() { + this.prevLosses = currLosses; + this.currLosses = 0.d; + this.curIter++; } public int getCurrentIteration() { return curIter; } - public float getCurrentEta() { - return curEta; - } - } diff --git a/core/src/main/java/hivemall/fm/FactorizationMachineUDTF.java b/core/src/main/java/hivemall/fm/FactorizationMachineUDTF.java index 36af12747..65b6ba717 100644 --- a/core/src/main/java/hivemall/fm/FactorizationMachineUDTF.java +++ b/core/src/main/java/hivemall/fm/FactorizationMachineUDTF.java @@ -20,11 +20,11 @@ import hivemall.UDTFWithOptions; import hivemall.common.ConversionState; +import hivemall.fm.FMStringFeatureMapModel.Entry; import hivemall.optimizer.EtaEstimator; import hivemall.optimizer.LossFunctions; import hivemall.optimizer.LossFunctions.LossFunction; import hivemall.optimizer.LossFunctions.LossType; -import hivemall.fm.FMStringFeatureMapModel.Entry; import hivemall.utils.collections.IMapIterator; import hivemall.utils.hadoop.HiveUtils; import hivemall.utils.io.FileUtils; @@ -539,8 +539,8 @@ protected void runTrainingIteration(int iterations) throws HiveException { } inputBuf.flip(); - int iter = 2; - for (; iter <= iterations; iter++) { + for (int iter = 2; iter <= iterations; iter++) { + _cvState.next(); reportProgress(reporter); setCounterValue(iterCounter, iter); @@ -557,12 +557,12 @@ protected void runTrainingIteration(int iterations) throws HiveException { ++_t; train(x, y, adaregr); } - if (_cvState.isConverged(iter, numTrainingExamples)) { + if (_cvState.isConverged(numTrainingExamples)) { break; } inputBuf.rewind(); } - LOG.info("Performed " + Math.min(iter, iterations) + " iterations of " + LOG.info("Performed " + _cvState.getCurrentIteration() + " iterations of " + NumberUtils.formatNumber(numTrainingExamples) + " training examples on memory (thus " + NumberUtils.formatNumber(_t) + " training updates in total) "); @@ -587,8 +587,8 @@ protected void runTrainingIteration(int iterations) throws HiveException { } // run iterations - int iter = 2; - for (; iter <= iterations; iter++) { + for (int iter = 2; iter <= iterations; iter++) { + _cvState.next(); 
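// ---------------------------------------------------------------------------
// A minimal sketch of the iteration/convergence contract that the refactored
// ConversionState imposes on callers such as this loop: next() now has no
// arguments and isConverged() no longer takes the iteration number.
// Illustrative only: `maxIter`, `numTrainingExamples`, and
// `replayBufferedExamples()` are hypothetical stand-ins for each UDTF's own
// fields and its buffered replay pass.
//
//   ConversionState cv = new ConversionState(true, 0.005d); // cvtest on, cv_rate
//   for (int iter = 2; iter <= maxIter; iter++) { // pass 1 already ran in process()
//       cv.next(); // prevLosses <- currLosses; currLosses <- 0; curIter++
//       replayBufferedExamples(); // each update calls cv.incrLoss(loss)
//       if (cv.isConverged(numTrainingExamples)) {
//           break; // relative loss change stayed below cv_rate on two successive checks
//       }
//   }
// ---------------------------------------------------------------------------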
setCounterValue(iterCounter, iter); inputBuf.clear(); @@ -639,11 +639,11 @@ protected void runTrainingIteration(int iterations) throws HiveException { } inputBuf.compact(); } - if (_cvState.isConverged(iter, numTrainingExamples)) { + if (_cvState.isConverged(numTrainingExamples)) { break; } } - LOG.info("Performed " + Math.min(iter, iterations) + " iterations of " + LOG.info("Performed " + _cvState.getCurrentIteration() + " iterations of " + NumberUtils.formatNumber(numTrainingExamples) + " training examples on a secondary storage (thus " + NumberUtils.formatNumber(_t) + " training updates in total)"); diff --git a/core/src/main/java/hivemall/mf/BPRMatrixFactorizationUDTF.java b/core/src/main/java/hivemall/mf/BPRMatrixFactorizationUDTF.java index 56a1992ac..141b2618f 100644 --- a/core/src/main/java/hivemall/mf/BPRMatrixFactorizationUDTF.java +++ b/core/src/main/java/hivemall/mf/BPRMatrixFactorizationUDTF.java @@ -20,8 +20,8 @@ import hivemall.UDTFWithOptions; import hivemall.common.ConversionState; -import hivemall.optimizer.EtaEstimator; import hivemall.mf.FactorizedModel.RankInitScheme; +import hivemall.optimizer.EtaEstimator; import hivemall.utils.hadoop.HiveUtils; import hivemall.utils.io.FileUtils; import hivemall.utils.io.NioFixedSegment; @@ -479,8 +479,8 @@ private final void runIterativeTraining(@Nonnegative final int iterations) throw } inputBuf.flip(); - int iter = 2; - for (; iter <= iterations; iter++) { + for (int iter = 2; iter <= iterations; iter++) { + cvState.next(); reportProgress(reporter); setCounterValue(iterCounter, iter); @@ -493,8 +493,7 @@ private final void runIterativeTraining(@Nonnegative final int iterations) throw train(u, i, j); } cvState.multiplyLoss(0.5d); - cvState.logState(iter, eta()); - if (cvState.isConverged(iter, numTrainingExamples)) { + if (cvState.isConverged(numTrainingExamples)) { break; } if (cvState.isLossIncreased()) { @@ -504,7 +503,7 @@ private final void runIterativeTraining(@Nonnegative final int iterations) throw } inputBuf.rewind(); } - LOG.info("Performed " + Math.min(iter, iterations) + " iterations of " + LOG.info("Performed " + cvState.getCurrentIteration() + " iterations of " + NumberUtils.formatNumber(numTrainingExamples) + " training examples on memory (thus " + NumberUtils.formatNumber(count) + " training updates in total) "); @@ -531,8 +530,8 @@ private final void runIterativeTraining(@Nonnegative final int iterations) throw } // run iterations - int iter = 2; - for (; iter <= iterations; iter++) { + for (int iter = 2; iter <= iterations; iter++) { + cvState.next(); setCounterValue(iterCounter, iter); inputBuf.clear(); @@ -569,8 +568,7 @@ private final void runIterativeTraining(@Nonnegative final int iterations) throw inputBuf.compact(); } cvState.multiplyLoss(0.5d); - cvState.logState(iter, eta()); - if (cvState.isConverged(iter, numTrainingExamples)) { + if (cvState.isConverged(numTrainingExamples)) { break; } if (cvState.isLossIncreased()) { @@ -579,7 +577,7 @@ private final void runIterativeTraining(@Nonnegative final int iterations) throw etaEstimator.update(0.5f); } } - LOG.info("Performed " + Math.min(iter, iterations) + " iterations of " + LOG.info("Performed " + cvState.getCurrentIteration() + " iterations of " + NumberUtils.formatNumber(numTrainingExamples) + " training examples using a secondary storage (thus " + NumberUtils.formatNumber(count) + " training updates in total)"); diff --git a/core/src/main/java/hivemall/mf/OnlineMatrixFactorizationUDTF.java 
b/core/src/main/java/hivemall/mf/OnlineMatrixFactorizationUDTF.java index bfc1f1996..66ec60d6f 100644 --- a/core/src/main/java/hivemall/mf/OnlineMatrixFactorizationUDTF.java +++ b/core/src/main/java/hivemall/mf/OnlineMatrixFactorizationUDTF.java @@ -477,8 +477,8 @@ protected final void runIterativeTraining(@Nonnegative final int iterations) } inputBuf.flip(); - int iter = 2; - for (; iter <= iterations; iter++) { + for (int iter = 2; iter <= iterations; iter++) { + cvState.next(); reportProgress(reporter); setCounterValue(iterCounter, iter); @@ -491,12 +491,12 @@ protected final void runIterativeTraining(@Nonnegative final int iterations) train(user, item, rating); } cvState.multiplyLoss(0.5d); - if (cvState.isConverged(iter, numTrainingExamples)) { + if (cvState.isConverged(numTrainingExamples)) { break; } inputBuf.rewind(); } - logger.info("Performed " + Math.min(iter, iterations) + " iterations of " + logger.info("Performed " + cvState.getCurrentIteration() + " iterations of " + NumberUtils.formatNumber(numTrainingExamples) + " training examples on memory (thus " + NumberUtils.formatNumber(count) + " training updates in total) "); @@ -523,8 +523,8 @@ protected final void runIterativeTraining(@Nonnegative final int iterations) } // run iterations - int iter = 2; - for (; iter <= iterations; iter++) { + for (int iter = 2; iter <= iterations; iter++) { + cvState.next(); setCounterValue(iterCounter, iter); inputBuf.clear(); @@ -561,11 +561,11 @@ protected final void runIterativeTraining(@Nonnegative final int iterations) inputBuf.compact(); } cvState.multiplyLoss(0.5d); - if (cvState.isConverged(iter, numTrainingExamples)) { + if (cvState.isConverged(numTrainingExamples)) { break; } } - logger.info("Performed " + Math.min(iter, iterations) + " iterations of " + logger.info("Performed " + cvState.getCurrentIteration() + " iterations of " + NumberUtils.formatNumber(numTrainingExamples) + " training examples using a secondary storage (thus " + NumberUtils.formatNumber(count) + " training updates in total)"); diff --git a/core/src/main/java/hivemall/model/FeatureValue.java b/core/src/main/java/hivemall/model/FeatureValue.java index 11aa8f007..ecba9a724 100644 --- a/core/src/main/java/hivemall/model/FeatureValue.java +++ b/core/src/main/java/hivemall/model/FeatureValue.java @@ -28,7 +28,7 @@ public final class FeatureValue { - private/* final */Object feature; + private/* final */Object feature; // possible types: String, Text, Integer, Long private/* final */double value; public FeatureValue() {}// used for Probe @@ -163,5 +163,4 @@ public static void parseFeatureAsString(@Nonnull final String s, probe.value = 1.d; } } - } diff --git a/core/src/main/java/hivemall/optimizer/Optimizer.java b/core/src/main/java/hivemall/optimizer/Optimizer.java index bbd2320d3..0f8283315 100644 --- a/core/src/main/java/hivemall/optimizer/Optimizer.java +++ b/core/src/main/java/hivemall/optimizer/Optimizer.java @@ -51,7 +51,7 @@ static abstract class OptimizerBase implements Optimizer { @Nonnull protected final Regularization _reg; @Nonnegative - protected int _numStep = 1; + protected long _numStep = 1L; public OptimizerBase(@Nonnull Map options) { this._eta = EtaEstimator.get(options); diff --git a/core/src/main/java/hivemall/regression/GeneralRegressionUDTF.java b/core/src/main/java/hivemall/regression/GeneralRegressionUDTF.java index 1bd9393a4..a34a6e665 100644 --- a/core/src/main/java/hivemall/regression/GeneralRegressionUDTF.java +++ b/core/src/main/java/hivemall/regression/GeneralRegressionUDTF.java @@ -41,8 
+41,8 @@ public final class GeneralRegressionUDTF extends GeneralLearnerBaseUDTF { @Override protected String getLossOptionDescription() { - return "Loss function [default: SquaredLoss/squared, QuantileLoss/quantile, " - + "EpsilonInsensitiveLoss/epsilon_insensitive, HuberLoss/huber]"; + return "Loss function [SquaredLoss (default), QuantileLoss, EpsilonInsensitiveLoss, " + + "SquaredEpsilonInsensitiveLoss, HuberLoss]"; } @Override diff --git a/core/src/test/java/hivemall/classifier/GeneralClassifierUDTFTest.java b/core/src/test/java/hivemall/classifier/GeneralClassifierUDTFTest.java index 1c7a90e06..34de2d902 100644 --- a/core/src/test/java/hivemall/classifier/GeneralClassifierUDTFTest.java +++ b/core/src/test/java/hivemall/classifier/GeneralClassifierUDTFTest.java @@ -35,11 +35,16 @@ import org.apache.hadoop.hive.ql.exec.UDFArgumentException; import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.ql.udf.generic.Collector; import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.LongWritable; +import org.apache.hadoop.io.Text; + import org.junit.Assert; import org.junit.Test; @@ -82,6 +87,103 @@ public void testUnsupportedRegularization() throws Exception { udtf.initialize(new ObjectInspector[] {stringListOI, intOI, params}); } + @Test + public void testNoOptions() throws Exception { + List x = Arrays.asList("1:-2", "2:-1"); + int y = 0; + + GeneralClassifierUDTF udtf = new GeneralClassifierUDTF(); + ObjectInspector intOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector; + ObjectInspector stringOI = PrimitiveObjectInspectorFactory.javaStringObjectInspector; + ListObjectInspector stringListOI = ObjectInspectorFactory.getStandardListObjectInspector(stringOI); + + udtf.initialize(new ObjectInspector[] {stringListOI, intOI}); + + udtf.process(new Object[] {x, y}); + + udtf.finalizeTraining(); + + float score = udtf.predict(udtf.parseFeatures(x)); + int predicted = score > 0.f ? 
1 : 0; + Assert.assertTrue(y == predicted); + } + + private void testFeature(List x, ObjectInspector featureOI, Class featureClass) + throws Exception { + int y = 0; + + GeneralClassifierUDTF udtf = new GeneralClassifierUDTF(); + ObjectInspector valueOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector; + ListObjectInspector featureListOI = ObjectInspectorFactory.getStandardListObjectInspector(featureOI); + + udtf.initialize(new ObjectInspector[] {featureListOI, valueOI}); + + final List modelFeatures = new ArrayList(); + udtf.setCollector(new Collector() { + @Override + public void collect(Object input) throws HiveException { + Object[] forwardMapObj = (Object[]) input; + modelFeatures.add(forwardMapObj[0]); + } + }); + + udtf.process(new Object[] {x, y}); + + udtf.close(); + + Class modelFeatureClass = modelFeatures.get(0).getClass(); + for (Object modelFeature : modelFeatures) { + Assert.assertEquals("All model features must have same type", modelFeatureClass, + modelFeature.getClass()); + } + + Assert.assertEquals( + "Model feature must correspond to UDTF output feature's object inspector", + featureClass, modelFeatureClass); + } + + @Test + public void testStringFeature() throws Exception { + List x = Arrays.asList("1:-2", "2:-1"); + ObjectInspector featureOI = PrimitiveObjectInspectorFactory.javaStringObjectInspector; + testFeature(x, featureOI, String.class); + } + + @Test + public void testTextFeature() throws Exception { + List x = Arrays.asList(new Text("1:-2"), new Text("2:-1")); + ObjectInspector featureOI = PrimitiveObjectInspectorFactory.writableStringObjectInspector; + testFeature(x, featureOI, Text.class); + } + + @Test + public void testIntegerFeature() throws Exception { + List x = Arrays.asList(111, 222); + ObjectInspector featureOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector; + testFeature(x, featureOI, Integer.class); + } + + @Test + public void testWritableIntFeature() throws Exception { + List x = Arrays.asList(new IntWritable(111), new IntWritable(222)); + ObjectInspector featureOI = PrimitiveObjectInspectorFactory.writableIntObjectInspector; + testFeature(x, featureOI, IntWritable.class); + } + + @Test + public void testLongFeature() throws Exception { + List x = Arrays.asList(111L, 222L); + ObjectInspector featureOI = PrimitiveObjectInspectorFactory.javaLongObjectInspector; + testFeature(x, featureOI, Long.class); + } + + @Test + public void testWritableLongFeature() throws Exception { + List x = Arrays.asList(new LongWritable(111L), new LongWritable(222L)); + ObjectInspector featureOI = PrimitiveObjectInspectorFactory.writableLongObjectInspector; + testFeature(x, featureOI, LongWritable.class); + } + private void run(@Nonnull String options) throws Exception { println(options); @@ -95,8 +197,6 @@ private void run(@Nonnull String options) throws Exception { int[] labels = new int[] {0, 0, 0, 1, 1, 1}; - int maxIter = 512; - GeneralClassifierUDTF udtf = new GeneralClassifierUDTF(); ObjectInspector intOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector; ObjectInspector stringOI = PrimitiveObjectInspectorFactory.javaStringObjectInspector; @@ -106,19 +206,17 @@ private void run(@Nonnull String options) throws Exception { udtf.initialize(new ObjectInspector[] {stringListOI, intOI, params}); - double cumLossPrev = Double.MAX_VALUE; - double cumLoss = 0.d; - int it = 0; - while ((it < maxIter) && (Math.abs(cumLoss - cumLossPrev) > 1e-3f)) { - cumLossPrev = cumLoss; - udtf.resetCumulativeLoss(); - for (int i = 0, size = samplesList.size(); i < 
size; i++) { - udtf.process(new Object[] {samplesList.get(i), labels[i]}); - } - cumLoss = udtf.getCumulativeLoss(); - println("Iter: " + ++it + ", Cumulative loss: " + cumLoss); + for (int i = 0, size = samplesList.size(); i < size; i++) { + udtf.process(new Object[] {samplesList.get(i), labels[i]}); } - Assert.assertTrue(cumLoss / samplesList.size() < 0.5d); + + udtf.finalizeTraining(); + + double cumLoss = udtf.getCumulativeLoss(); + println("Cumulative loss: " + cumLoss); + double normalizedLoss = cumLoss / samplesList.size(); + Assert.assertTrue("cumLoss: " + cumLoss + ", normalizedLoss: " + normalizedLoss + + "\noptions: " + options, normalizedLoss < 0.5d); int numTests = 0; int numCorrect = 0; @@ -157,7 +255,8 @@ public void test() throws Exception { } for (String loss : lossFunctions) { - String options = "-opt " + opt + " -reg " + reg + " -loss " + loss; + String options = "-opt " + opt + " -reg " + reg + " -loss " + loss + + " -cv_rate 0.001 -iter 512"; // sparse run(options); @@ -178,15 +277,13 @@ public void test() throws Exception { @SuppressWarnings("unchecked") @Test public void testNews20() throws IOException, ParseException, HiveException { - int nIter = 10; - GeneralClassifierUDTF udtf = new GeneralClassifierUDTF(); ObjectInspector intOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector; ObjectInspector stringOI = PrimitiveObjectInspectorFactory.javaStringObjectInspector; ListObjectInspector stringListOI = ObjectInspectorFactory.getStandardListObjectInspector(stringOI); ObjectInspector params = ObjectInspectorUtils.getConstantObjectInspector( PrimitiveObjectInspectorFactory.javaStringObjectInspector, - "-opt SGD -loss logloss -reg L2 -lambda 0.1"); + "-opt SGD -loss logloss -reg L2 -lambda 0.1 -cv_rate 0.005"); udtf.initialize(new ObjectInspector[] {stringListOI, intOI, params}); @@ -213,13 +310,7 @@ public void testNews20() throws IOException, ParseException, HiveException { news20.close(); // perform SGD iterations - for (int it = 1; it < nIter; it++) { - for (int i = 0, size = wordsList.size(); i < size; i++) { - words = wordsList.get(i); - int label = labels.get(i); - udtf.process(new Object[] {words, label}); - } - } + udtf.finalizeTraining(); int numTests = 0; int numCorrect = 0; diff --git a/core/src/test/java/hivemall/regression/GeneralRegressionUDTFTest.java b/core/src/test/java/hivemall/regression/GeneralRegressionUDTFTest.java index cfe9651bf..e39a22613 100644 --- a/core/src/test/java/hivemall/regression/GeneralRegressionUDTFTest.java +++ b/core/src/test/java/hivemall/regression/GeneralRegressionUDTFTest.java @@ -25,11 +25,17 @@ import javax.annotation.Nonnull; import org.apache.hadoop.hive.ql.exec.UDFArgumentException; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.ql.udf.generic.Collector; import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.LongWritable; +import org.apache.hadoop.io.Text; + import org.junit.Assert; import org.junit.Test; @@ -84,6 +90,102 @@ public void testUnsupportedRegularization() throws Exception { udtf.initialize(new ObjectInspector[] {stringListOI, floatOI, params}); } + @Test + public 
void testNoOptions() throws Exception { + List x = Arrays.asList("1:-2", "2:-1"); + float y = 0.f; + + GeneralRegressionUDTF udtf = new GeneralRegressionUDTF(); + ObjectInspector floatOI = PrimitiveObjectInspectorFactory.javaFloatObjectInspector; + ObjectInspector stringOI = PrimitiveObjectInspectorFactory.javaStringObjectInspector; + ListObjectInspector stringListOI = ObjectInspectorFactory.getStandardListObjectInspector(stringOI); + + udtf.initialize(new ObjectInspector[] {stringListOI, floatOI}); + + udtf.process(new Object[] {x, y}); + + udtf.finalizeTraining(); + + float predicted = udtf.predict(udtf.parseFeatures(x)); + Assert.assertEquals(y, predicted, 1E-5); + } + + private void testFeature(List x, ObjectInspector featureOI, Class featureClass) + throws Exception { + float y = 0.f; + + GeneralRegressionUDTF udtf = new GeneralRegressionUDTF(); + ObjectInspector valueOI = PrimitiveObjectInspectorFactory.javaFloatObjectInspector; + ListObjectInspector featureListOI = ObjectInspectorFactory.getStandardListObjectInspector(featureOI); + + udtf.initialize(new ObjectInspector[] {featureListOI, valueOI}); + + final List modelFeatures = new ArrayList(); + udtf.setCollector(new Collector() { + @Override + public void collect(Object input) throws HiveException { + Object[] forwardMapObj = (Object[]) input; + modelFeatures.add(forwardMapObj[0]); + } + }); + + udtf.process(new Object[] {x, y}); + + udtf.close(); + + Class modelFeatureClass = modelFeatures.get(0).getClass(); + for (Object modelFeature : modelFeatures) { + Assert.assertEquals("All model features must have same type", modelFeatureClass, + modelFeature.getClass()); + } + + Assert.assertEquals( + "Model feature must correspond to UDTF output feature's object inspector", + featureClass, modelFeatureClass); + } + + @Test + public void testStringFeature() throws Exception { + List x = Arrays.asList("1:-2", "2:-1"); + ObjectInspector featureOI = PrimitiveObjectInspectorFactory.javaStringObjectInspector; + testFeature(x, featureOI, String.class); + } + + @Test + public void testTextFeature() throws Exception { + List x = Arrays.asList(new Text("1:-2"), new Text("2:-1")); + ObjectInspector featureOI = PrimitiveObjectInspectorFactory.writableStringObjectInspector; + testFeature(x, featureOI, Text.class); + } + + @Test + public void testIntegerFeature() throws Exception { + List x = Arrays.asList(111, 222); + ObjectInspector featureOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector; + testFeature(x, featureOI, Integer.class); + } + + @Test + public void testWritableIntFeature() throws Exception { + List x = Arrays.asList(new IntWritable(111), new IntWritable(222)); + ObjectInspector featureOI = PrimitiveObjectInspectorFactory.writableIntObjectInspector; + testFeature(x, featureOI, IntWritable.class); + } + + @Test + public void testLongFeature() throws Exception { + List x = Arrays.asList(111L, 222L); + ObjectInspector featureOI = PrimitiveObjectInspectorFactory.javaLongObjectInspector; + testFeature(x, featureOI, Long.class); + } + + @Test + public void testWritableLongFeature() throws Exception { + List x = Arrays.asList(new LongWritable(111L), new LongWritable(222L)); + ObjectInspector featureOI = PrimitiveObjectInspectorFactory.writableLongObjectInspector; + testFeature(x, featureOI, LongWritable.class); + } + private void run(@Nonnull String options) throws Exception { println(options); @@ -108,9 +210,6 @@ private void run(@Nonnull String options) throws Exception { x2 += x2Step; } - int numTrain = (int) (numSamples * 0.8); - int 
maxIter = 512; - GeneralRegressionUDTF udtf = new GeneralRegressionUDTF(); ObjectInspector floatOI = PrimitiveObjectInspectorFactory.javaFloatObjectInspector; ObjectInspector stringOI = PrimitiveObjectInspectorFactory.javaStringObjectInspector; @@ -120,23 +219,29 @@ private void run(@Nonnull String options) throws Exception { udtf.initialize(new ObjectInspector[] {stringListOI, floatOI, params}); - double cumLossPrev = Double.MAX_VALUE; - double cumLoss = 0.d; - int it = 0; - while ((it < maxIter) && (Math.abs(cumLoss - cumLossPrev) > 1e-3f)) { - cumLossPrev = cumLoss; - udtf.resetCumulativeLoss(); - for (int i = 0; i < numTrain; i++) { - udtf.process(new Object[] {samplesList.get(i), (Float) ys.get(i)}); - } - cumLoss = udtf.getCumulativeLoss(); - println("Iter: " + ++it + ", Cumulative loss: " + cumLoss); + float accum = 0.f; + for (int i = 0; i < numSamples; i++) { + float y = ys.get(i).floatValue(); + float predicted = udtf.predict(udtf.parseFeatures(samplesList.get(i))); + accum += Math.abs(y - predicted); } - Assert.assertTrue(cumLoss / numTrain < 0.1d); + float maeInit = accum / numSamples; + println("Mean absolute error before training: " + maeInit); - float accum = 0.f; + for (int i = 0; i < numSamples; i++) { + udtf.process(new Object[] {samplesList.get(i), (Float) ys.get(i)}); + } + + udtf.finalizeTraining(); - for (int i = numTrain; i < numSamples; i++) { + double cumLoss = udtf.getCumulativeLoss(); + println("Cumulative loss: " + cumLoss); + double normalizedLoss = cumLoss / numSamples; + Assert.assertTrue("cumLoss: " + cumLoss + ", normalizedLoss: " + normalizedLoss + + "\noptions: " + options, normalizedLoss < 0.1d); + + accum = 0.f; + for (int i = 0; i < numSamples; i++) { float y = ys.get(i).floatValue(); float predicted = udtf.predict(udtf.parseFeatures(samplesList.get(i))); @@ -144,10 +249,10 @@ private void run(@Nonnull String options) throws Exception { accum += Math.abs(y - predicted); } - - float err = accum / (numSamples - numTrain); - println("Mean absolute error: " + err); - Assert.assertTrue(err < 0.2f); + float mae = accum / numSamples; + println("Mean absolute error after training: " + mae); + Assert.assertTrue("accum: " + accum + ", mae (init):" + maeInit + ", mae:" + mae + + "\noptions: " + options, mae < maeInit); } @Test @@ -165,7 +270,7 @@ public void test() throws Exception { for (String loss : lossFunctions) { String options = "-opt " + opt + " -reg " + reg + " -loss " + loss - + " -lambda 1e-6 -eta0 1e-1"; + + " -iter 512"; // sparse run(options); From 7bd60a3f1da934721f390f86873d08bc5393b199 Mon Sep 17 00:00:00 2001 From: Makoto Yui Date: Tue, 11 Jul 2017 16:00:56 +0900 Subject: [PATCH 02/13] Refactored read/writeFeatureValue to be more straightforward --- .../java/hivemall/GeneralLearnerBaseUDTF.java | 37 ++++++++++++------- 1 file changed, 23 insertions(+), 14 deletions(-) diff --git a/core/src/main/java/hivemall/GeneralLearnerBaseUDTF.java b/core/src/main/java/hivemall/GeneralLearnerBaseUDTF.java index f8d56be8f..8dc02dff0 100644 --- a/core/src/main/java/hivemall/GeneralLearnerBaseUDTF.java +++ b/core/src/main/java/hivemall/GeneralLearnerBaseUDTF.java @@ -64,14 +64,14 @@ import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.JavaIntObjectInspector; +import 
org.apache.hadoop.hive.serde2.objectinspector.primitive.JavaLongObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.JavaStringObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils; -import org.apache.hadoop.hive.serde2.objectinspector.primitive.JavaStringObjectInspector; -import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableStringObjectInspector; -import org.apache.hadoop.hive.serde2.objectinspector.primitive.JavaIntObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableIntObjectInspector; -import org.apache.hadoop.hive.serde2.objectinspector.primitive.JavaLongObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableLongObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableStringObjectInspector; import org.apache.hadoop.io.FloatWritable; import org.apache.hadoop.io.IntWritable; import org.apache.hadoop.io.LongWritable; @@ -356,29 +356,38 @@ protected void recordTrainSampleToTempFile(@Nonnull final FeatureValue[] feature buf.putFloat(target); } - private void writeFeatureValue(@Nonnull ByteBuffer buf, @Nonnull FeatureValue f) { + private static void writeFeatureValue(@Nonnull final ByteBuffer buf, + @Nonnull final FeatureValue f) { NIOUtils.putString(f.getFeatureAsString(), buf); buf.putDouble(f.getValue()); } - private FeatureValue readFeatureValue(@Nonnull ByteBuffer buf) { - Object feature = NIOUtils.getString(buf); + @Nonnull + private static FeatureValue readFeatureValue(@Nonnull final ByteBuffer buf, + @Nonnull final FeatureType featureType) { + final String featureStr = NIOUtils.getString(buf); + final Object feature; switch (featureType) { + case JavaString: + feature = featureStr; + break; case Text: - feature = new Text((String) feature); + feature = new Text(featureStr); break; case JavaInteger: - feature = Integer.parseInt((String) feature); + feature = Integer.valueOf(featureStr); break; case WritableInt: - feature = new IntWritable(Integer.parseInt((String) feature)); + feature = new IntWritable(Integer.parseInt(featureStr)); break; case JavaLong: - feature = Long.parseLong((String) feature); + feature = Long.valueOf(featureStr); break; case WritableLong: - feature = new LongWritable(Long.parseLong((String) feature)); + feature = new LongWritable(Long.parseLong(featureStr)); break; + default: + throw new IllegalStateException("Unexpected feature type: " + featureType); } double value = buf.getDouble(); return new FeatureValue(feature, value); @@ -561,7 +570,7 @@ protected final void runIterativeTraining(@Nonnegative final int iterations) int featureVectorLength = buf.getInt(); final FeatureValue[] featureVector = new FeatureValue[featureVectorLength]; for (int j = 0; j < featureVectorLength; j++) { - featureVector[j] = readFeatureValue(buf); + featureVector[j] = readFeatureValue(buf, featureType); } float target = buf.getFloat(); train(featureVector, target); @@ -644,7 +653,7 @@ protected final void runIterativeTraining(@Nonnegative final int iterations) int featureVectorLength = buf.getInt(); final FeatureValue[] featureVector = new FeatureValue[featureVectorLength]; for (int j = 0; j < featureVectorLength; j++) { - featureVector[j] = readFeatureValue(buf); + featureVector[j] = readFeatureValue(buf, featureType); } float target = buf.getFloat(); train(featureVector, target); From 
3f5087311ad5ffb3f09e556fb84710503ac99bb7 Mon Sep 17 00:00:00 2001 From: Makoto Yui Date: Thu, 13 Jul 2017 16:33:24 +0900 Subject: [PATCH 03/13] Removed IDE warning --- .../java/hivemall/classifier/GeneralClassifierUDTFTest.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/test/java/hivemall/classifier/GeneralClassifierUDTFTest.java b/core/src/test/java/hivemall/classifier/GeneralClassifierUDTFTest.java index 34de2d902..e973a4d63 100644 --- a/core/src/test/java/hivemall/classifier/GeneralClassifierUDTFTest.java +++ b/core/src/test/java/hivemall/classifier/GeneralClassifierUDTFTest.java @@ -108,7 +108,7 @@ public void testNoOptions() throws Exception { Assert.assertTrue(y == predicted); } - private void testFeature(List x, ObjectInspector featureOI, Class featureClass) + private void testFeature(List x, ObjectInspector featureOI, Class featureClass) throws Exception { int y = 0; @@ -131,7 +131,7 @@ public void collect(Object input) throws HiveException { udtf.close(); - Class modelFeatureClass = modelFeatures.get(0).getClass(); + Class modelFeatureClass = modelFeatures.get(0).getClass(); for (Object modelFeature : modelFeatures) { Assert.assertEquals("All model features must have same type", modelFeatureClass, modelFeature.getClass()); From 26bdd2ef36ae688f647b8aa8dbdff23b1737e436 Mon Sep 17 00:00:00 2001 From: Makoto Yui Date: Thu, 13 Jul 2017 16:35:58 +0900 Subject: [PATCH 04/13] Fixed to support various OI such as LazyStringObjectInspector --- .../java/hivemall/GeneralLearnerBaseUDTF.java | 58 ++++++------------- 1 file changed, 18 insertions(+), 40 deletions(-) diff --git a/core/src/main/java/hivemall/GeneralLearnerBaseUDTF.java b/core/src/main/java/hivemall/GeneralLearnerBaseUDTF.java index 8dc02dff0..095795688 100644 --- a/core/src/main/java/hivemall/GeneralLearnerBaseUDTF.java +++ b/core/src/main/java/hivemall/GeneralLearnerBaseUDTF.java @@ -62,20 +62,17 @@ import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption; import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; -import org.apache.hadoop.hive.serde2.objectinspector.primitive.JavaIntObjectInspector; -import org.apache.hadoop.hive.serde2.objectinspector.primitive.JavaLongObjectInspector; -import org.apache.hadoop.hive.serde2.objectinspector.primitive.JavaStringObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.IntObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.LongObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils; -import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableIntObjectInspector; -import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableLongObjectInspector; -import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableStringObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.StringObjectInspector; import org.apache.hadoop.io.FloatWritable; import org.apache.hadoop.io.IntWritable; import org.apache.hadoop.io.LongWritable; -import 
org.apache.hadoop.io.Text; import org.apache.hadoop.mapred.Counters; import org.apache.hadoop.mapred.Reporter; @@ -83,13 +80,12 @@ public abstract class GeneralLearnerBaseUDTF extends LearnerBaseUDTF { private static final Log logger = LogFactory.getLog(GeneralLearnerBaseUDTF.class); public enum FeatureType { - JavaString, Text, JavaInteger, WritableInt, JavaLong, WritableLong + JavaString, WritableInt, WritableLong } private ListObjectInspector featureListOI; private PrimitiveObjectInspector featureInputOI; private PrimitiveObjectInspector targetOI; - private boolean parseFeature; private FeatureType featureType; // ----------------------------------------- @@ -235,24 +231,17 @@ protected PrimitiveObjectInspector processFeaturesOI(@Nonnull ObjectInspector ar this.featureListOI = (ListObjectInspector) arg; ObjectInspector featureRawOI = featureListOI.getListElementObjectInspector(); HiveUtils.validateFeatureOI(featureRawOI); - if (featureRawOI instanceof JavaStringObjectInspector) { + if (featureRawOI instanceof StringObjectInspector) { this.featureType = FeatureType.JavaString; - } else if (featureRawOI instanceof WritableStringObjectInspector) { - this.featureType = FeatureType.Text; - } else if (featureRawOI instanceof JavaIntObjectInspector) { - this.featureType = FeatureType.JavaInteger; - } else if (featureRawOI instanceof WritableIntObjectInspector) { + } else if (featureRawOI instanceof IntObjectInspector) { this.featureType = FeatureType.WritableInt; - } else if (featureRawOI instanceof JavaLongObjectInspector) { - this.featureType = FeatureType.JavaLong; - } else if (featureRawOI instanceof WritableLongObjectInspector) { + } else if (featureRawOI instanceof LongObjectInspector) { this.featureType = FeatureType.WritableLong; } else { - throw new UDFArgumentException("Feature object inspector must be one of " - + "[JavaString, WritableString, JavaInt, WritableInt, Long, WritableLong]: " - + featureRawOI.toString()); + throw new UDFArgumentException( + "Feature object inspector must be one of [Text, Int, BigInt]: " + + featureRawOI.toString()); } - this.parseFeature = HiveUtils.isStringOI(featureRawOI); return HiveUtils.asPrimitiveObjectInspector(featureRawOI); } @@ -327,10 +316,10 @@ protected void recordTrainSampleToTempFile(@Nonnull final FeatureValue[] feature if (f == null) { continue; } - String feature = f.getFeatureAsString(); + int featureLength = f.getFeatureAsString().length(); // feature as String (even if it is Text or Integer) - featureVectorBytes += SizeOf.CHAR * feature.length(); + featureVectorBytes += SizeOf.CHAR * featureLength; // NIOUtils.putString() first puts the length of string before string itself featureVectorBytes += SizeOf.INT; @@ -371,18 +360,9 @@ private static FeatureValue readFeatureValue(@Nonnull final ByteBuffer buf, case JavaString: feature = featureStr; break; - case Text: - feature = new Text(featureStr); - break; - case JavaInteger: - feature = Integer.valueOf(featureStr); - break; case WritableInt: feature = new IntWritable(Integer.parseInt(featureStr)); break; - case JavaLong: - feature = Long.valueOf(featureStr); - break; case WritableLong: feature = new LongWritable(Long.parseLong(featureStr)); break; @@ -408,14 +388,12 @@ public final FeatureValue[] parseFeatures(@Nonnull final List features) { continue; } final FeatureValue fv; - if (parseFeature) { - if (featureType == FeatureType.JavaString) { - fv = FeatureValue.parseFeatureAsString((String) f); - } else { - fv = FeatureValue.parse(f); // = parse feature as Text - } + if (featureType
== FeatureType.JavaString) { + String s = f.toString(); + fv = FeatureValue.parse(s); } else { - Object k = ObjectInspectorUtils.copyToStandardObject(f, featureInspector); + Object k = ObjectInspectorUtils.copyToStandardObject(f, featureInspector, + ObjectInspectorCopyOption.WRITABLE); // should be IntWritable or LongWritable fv = new FeatureValue(k, 1.f); } featureVector[i] = fv; From eef5c328b02cbc3c720a1cd3f26d3b4c3a4f1e14 Mon Sep 17 00:00:00 2001 From: Makoto Yui Date: Thu, 13 Jul 2017 17:01:39 +0900 Subject: [PATCH 05/13] Fixed to hold a feature as a Primitive Java object because ObjectInspectorUtils.copyToStandardObject(f, featureInspector, ObjectInspectorCopyOption.WRITABLE) is inefficient --- .../java/hivemall/GeneralLearnerBaseUDTF.java | 35 +++++++++---------- 1 file changed, 16 insertions(+), 19 deletions(-) diff --git a/core/src/main/java/hivemall/GeneralLearnerBaseUDTF.java b/core/src/main/java/hivemall/GeneralLearnerBaseUDTF.java index 095795688..6bf38ce11 100644 --- a/core/src/main/java/hivemall/GeneralLearnerBaseUDTF.java +++ b/core/src/main/java/hivemall/GeneralLearnerBaseUDTF.java @@ -71,8 +71,6 @@ import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils; import org.apache.hadoop.hive.serde2.objectinspector.primitive.StringObjectInspector; import org.apache.hadoop.io.FloatWritable; -import org.apache.hadoop.io.IntWritable; -import org.apache.hadoop.io.LongWritable; import org.apache.hadoop.mapred.Counters; import org.apache.hadoop.mapred.Reporter; @@ -80,7 +78,7 @@ public abstract class GeneralLearnerBaseUDTF extends LearnerBaseUDTF { private static final Log logger = LogFactory.getLog(GeneralLearnerBaseUDTF.class); public enum FeatureType { - JavaString, WritableInt, WritableLong + STRING, INT, LONG } private ListObjectInspector featureListOI; @@ -230,17 +228,16 @@ protected PrimitiveObjectInspector processFeaturesOI(@Nonnull ObjectInspector ar throws UDFArgumentException { this.featureListOI = (ListObjectInspector) arg; ObjectInspector featureRawOI = featureListOI.getListElementObjectInspector(); - HiveUtils.validateFeatureOI(featureRawOI); if (featureRawOI instanceof StringObjectInspector) { - this.featureType = FeatureType.JavaString; + this.featureType = FeatureType.STRING; } else if (featureRawOI instanceof IntObjectInspector) { - this.featureType = FeatureType.WritableInt; + this.featureType = FeatureType.INT; } else if (featureRawOI instanceof LongObjectInspector) { - this.featureType = FeatureType.WritableLong; + this.featureType = FeatureType.LONG; } else { - throw new UDFArgumentException( - "Feature object inspector must be one of [Text, Int, BigInt]: " - + featureRawOI.toString()); + throw new UDFArgumentException("Feature object inspector must be one of " + "[StringObjectInspector, IntObjectInspector, LongObjectInspector]: " + + featureRawOI.toString()); } return HiveUtils.asPrimitiveObjectInspector(featureRawOI); } @@ -357,17 +354,17 @@ private static FeatureValue readFeatureValue(@Nonnull final ByteBuffer buf, final String featureStr = NIOUtils.getString(buf); final Object feature; switch (featureType) { - case JavaString: + case STRING: feature = featureStr; break; - case WritableInt: - feature = new IntWritable(Integer.parseInt(featureStr)); + case INT: + feature = Integer.valueOf(featureStr); break; - case WritableLong: - feature = new LongWritable(Long.parseLong(featureStr)); + case LONG: + feature = Long.valueOf(featureStr); break; default: - throw new IllegalStateException("Unexpected feature type: " +
featureType); + throw new IllegalStateException("Unexpected feature type " + featureType + " for feature: " + featureStr); } double value = buf.getDouble(); return new FeatureValue(feature, value); @@ -388,12 +385,12 @@ public final FeatureValue[] parseFeatures(@Nonnull final List features) { continue; } final FeatureValue fv; - if (featureType == FeatureType.JavaString) { + if (featureType == FeatureType.STRING) { String s = f.toString(); - fv = FeatureValue.parse(s); + fv = FeatureValue.parseFeatureAsString(s); } else { Object k = ObjectInspectorUtils.copyToStandardObject(f, featureInspector, - ObjectInspectorCopyOption.WRITABLE); // should be IntWritable or LongWritable + ObjectInspectorCopyOption.JAVA); // should be Integer or Long fv = new FeatureValue(k, 1.f); } featureVector[i] = fv; From a375e8fb083de05c515e44dad0b4af768013ef2f Mon Sep 17 00:00:00 2001 From: Makoto Yui Date: Thu, 13 Jul 2017 18:31:58 +0900 Subject: [PATCH 06/13] Changed error message of parseFeature() to a more friendly one and fixed unit tests --- .../java/hivemall/model/FeatureValue.java | 19 +++++- .../java/hivemall/utils/hadoop/HiveUtils.java | 44 +++++++++++++ .../classifier/GeneralClassifierUDTFTest.java | 61 ++++++++++++++----- .../regression/GeneralRegressionUDTFTest.java | 60 ++++++++++++++---- 4 files changed, 154 insertions(+), 30 deletions(-) diff --git a/core/src/main/java/hivemall/model/FeatureValue.java b/core/src/main/java/hivemall/model/FeatureValue.java index ecba9a724..11005e991 100644 --- a/core/src/main/java/hivemall/model/FeatureValue.java +++ b/core/src/main/java/hivemall/model/FeatureValue.java @@ -108,7 +108,11 @@ public static FeatureValue parse(@Nonnull final String s, final boolean mhash) String s1 = s.substring(0, pos); String s2 = s.substring(pos + 1); feature = mhash ? Integer.valueOf(MurmurHash3.murmurhash3(s1)) : new Text(s1); - weight = Double.parseDouble(s2); + try { + weight = Double.parseDouble(s2); + } catch (NumberFormatException e) { + throw new IllegalArgumentException("Failed to parse a feature value: " + s, e); + } } else { feature = mhash ?
Integer.valueOf(MurmurHash3.murmurhash3(s)) : new Text(s); weight = 1.d; @@ -135,7 +139,11 @@ public static FeatureValue parseFeatureAsString(@Nonnull final String s) if (pos > 0) { feature = s.substring(0, pos); String s2 = s.substring(pos + 1); - weight = Double.parseDouble(s2); + try { + weight = Double.parseDouble(s2); + } catch (NumberFormatException e) { + throw new IllegalArgumentException("Failed to parse a feature value: " + s, e); + } } else { feature = s; weight = 1.d; @@ -157,10 +165,15 @@ public static void parseFeatureAsString(@Nonnull final String s, if (pos > 0) { probe.feature = s.substring(0, pos); String s2 = s.substring(pos + 1); - probe.value = Double.parseDouble(s2); + try { + probe.value = Double.parseDouble(s2); + } catch (NumberFormatException e) { + throw new IllegalArgumentException("Failed to parse a feature value: " + s, e); + } } else { probe.feature = s; probe.value = 1.d; } } + } diff --git a/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java b/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java index 4ed1f123d..85e4be20b 100644 --- a/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java +++ b/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java @@ -28,6 +28,7 @@ import static hivemall.HivemallConstants.STRING_TYPE_NAME; import static hivemall.HivemallConstants.TINYINT_TYPE_NAME; +import java.nio.charset.StandardCharsets; import java.util.Arrays; import java.util.BitSet; import java.util.Collections; @@ -46,10 +47,14 @@ import org.apache.hadoop.hive.serde2.io.ByteWritable; import org.apache.hadoop.hive.serde2.io.DoubleWritable; import org.apache.hadoop.hive.serde2.io.ShortWritable; +import org.apache.hadoop.hive.serde2.lazy.ByteArrayRef; import org.apache.hadoop.hive.serde2.lazy.LazyDouble; import org.apache.hadoop.hive.serde2.lazy.LazyInteger; +import org.apache.hadoop.hive.serde2.lazy.LazyLong; import org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe; import org.apache.hadoop.hive.serde2.lazy.LazyString; +import org.apache.hadoop.hive.serde2.lazy.objectinspector.primitive.LazyPrimitiveObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.lazy.objectinspector.primitive.LazyStringObjectInspector; import org.apache.hadoop.hive.serde2.lazybinary.LazyBinaryArray; import org.apache.hadoop.hive.serde2.lazybinary.LazyBinaryMap; import org.apache.hadoop.hive.serde2.objectinspector.ConstantObjectInspector; @@ -1037,4 +1042,43 @@ public static Object castLazyBinaryObject(@Nonnull final Object obj) { } return obj; } + + @Nonnull + public static LazyString lazyString(@Nonnull final String str) { + return lazyString(str, (byte) '\\'); + } + + @Nonnull + public static LazyString lazyString(@Nonnull final String str, final byte escapeChar) { + LazyStringObjectInspector oi = LazyPrimitiveObjectInspectorFactory.getLazyStringObjectInspector( + false, escapeChar); + return lazyString(str, oi); + } + + @Nonnull + public static LazyString lazyString(@Nonnull final String str, + @Nonnull final LazyStringObjectInspector oi) { + LazyString lazy = new LazyString(oi); + ByteArrayRef ref = new ByteArrayRef(); + byte[] data = str.getBytes(StandardCharsets.UTF_8); + ref.setData(data); + lazy.init(ref, 0, data.length); + return lazy; + } + + @Nonnull + public static LazyInteger lazyInteger(@Nonnull final int v) { + LazyInteger lazy = new LazyInteger( + LazyPrimitiveObjectInspectorFactory.LAZY_INT_OBJECT_INSPECTOR); + lazy.getWritableObject().set(v); + return lazy; + } + + @Nonnull + public static LazyLong lazyLong(@Nonnull final long v) { + LazyLong lazy = new 
LazyLong(LazyPrimitiveObjectInspectorFactory.LAZY_LONG_OBJECT_INSPECTOR); + lazy.getWritableObject().set(v); + return lazy; + } + } diff --git a/core/src/test/java/hivemall/classifier/GeneralClassifierUDTFTest.java b/core/src/test/java/hivemall/classifier/GeneralClassifierUDTFTest.java index e973a4d63..dba4a0021 100644 --- a/core/src/test/java/hivemall/classifier/GeneralClassifierUDTFTest.java +++ b/core/src/test/java/hivemall/classifier/GeneralClassifierUDTFTest.java @@ -18,6 +18,9 @@ */ package hivemall.classifier; +import static hivemall.utils.hadoop.HiveUtils.lazyInteger; +import static hivemall.utils.hadoop.HiveUtils.lazyLong; +import static hivemall.utils.hadoop.HiveUtils.lazyString; import hivemall.utils.math.MathUtils; import java.io.BufferedReader; @@ -36,6 +39,11 @@ import org.apache.hadoop.hive.ql.exec.UDFArgumentException; import org.apache.hadoop.hive.ql.metadata.HiveException; import org.apache.hadoop.hive.ql.udf.generic.Collector; +import org.apache.hadoop.hive.serde2.lazy.LazyInteger; +import org.apache.hadoop.hive.serde2.lazy.LazyLong; +import org.apache.hadoop.hive.serde2.lazy.LazyString; +import org.apache.hadoop.hive.serde2.lazy.objectinspector.primitive.LazyPrimitiveObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.lazy.objectinspector.primitive.LazyStringObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; @@ -44,7 +52,6 @@ import org.apache.hadoop.io.IntWritable; import org.apache.hadoop.io.LongWritable; import org.apache.hadoop.io.Text; - import org.junit.Assert; import org.junit.Test; @@ -108,8 +115,8 @@ public void testNoOptions() throws Exception { Assert.assertTrue(y == predicted); } - private void testFeature(List x, ObjectInspector featureOI, Class featureClass) - throws Exception { + private void testFeature(@Nonnull List x, @Nonnull ObjectInspector featureOI, + @Nonnull Class featureClass, @Nonnull Class modelFeatureClass) throws Exception { int y = 0; GeneralClassifierUDTF udtf = new GeneralClassifierUDTF(); @@ -131,57 +138,83 @@ public void collect(Object input) throws HiveException { udtf.close(); - Class modelFeatureClass = modelFeatures.get(0).getClass(); + Assert.assertFalse(modelFeatures.isEmpty()); for (Object modelFeature : modelFeatures) { Assert.assertEquals("All model features must have same type", modelFeatureClass, modelFeature.getClass()); } - - Assert.assertEquals( - "Model feature must correspond to UDTF output feature's object inspector", - featureClass, modelFeatureClass); } @Test public void testStringFeature() throws Exception { List x = Arrays.asList("1:-2", "2:-1"); ObjectInspector featureOI = PrimitiveObjectInspectorFactory.javaStringObjectInspector; - testFeature(x, featureOI, String.class); + testFeature(x, featureOI, String.class, String.class); + } + + @Test(expected = IllegalArgumentException.class) + public void testIllegalStringFeature() throws Exception { + List x = Arrays.asList("1:-2jjjjj", "2:-1"); + ObjectInspector featureOI = PrimitiveObjectInspectorFactory.javaStringObjectInspector; + testFeature(x, featureOI, String.class, String.class); + } + + @Test + public void testLazyStringFeature() throws Exception { + LazyStringObjectInspector oi = LazyPrimitiveObjectInspectorFactory.getLazyStringObjectInspector( + false, (byte) 0); + List x = Arrays.asList(lazyString("テスト:-2", oi), lazyString("漢字:-333.0", oi), + 
lazyString("test:-1")); + testFeature(x, oi, LazyString.class, String.class); } @Test public void testTextFeature() throws Exception { List x = Arrays.asList(new Text("1:-2"), new Text("2:-1")); ObjectInspector featureOI = PrimitiveObjectInspectorFactory.writableStringObjectInspector; - testFeature(x, featureOI, Text.class); + testFeature(x, featureOI, Text.class, String.class); } @Test public void testIntegerFeature() throws Exception { List x = Arrays.asList(111, 222); ObjectInspector featureOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector; - testFeature(x, featureOI, Integer.class); + testFeature(x, featureOI, Integer.class, Integer.class); + } + + @Test + public void testLazyIntegerFeature() throws Exception { + List x = Arrays.asList(lazyInteger(111), lazyInteger(222)); + ObjectInspector featureOI = LazyPrimitiveObjectInspectorFactory.LAZY_INT_OBJECT_INSPECTOR; + testFeature(x, featureOI, LazyInteger.class, Integer.class); } @Test public void testWritableIntFeature() throws Exception { List x = Arrays.asList(new IntWritable(111), new IntWritable(222)); ObjectInspector featureOI = PrimitiveObjectInspectorFactory.writableIntObjectInspector; - testFeature(x, featureOI, IntWritable.class); + testFeature(x, featureOI, IntWritable.class, Integer.class); } @Test public void testLongFeature() throws Exception { List x = Arrays.asList(111L, 222L); ObjectInspector featureOI = PrimitiveObjectInspectorFactory.javaLongObjectInspector; - testFeature(x, featureOI, Long.class); + testFeature(x, featureOI, Long.class, Long.class); + } + + @Test + public void testLazyLongFeature() throws Exception { + List x = Arrays.asList(lazyLong(111), lazyLong(222)); + ObjectInspector featureOI = LazyPrimitiveObjectInspectorFactory.LAZY_LONG_OBJECT_INSPECTOR; + testFeature(x, featureOI, LazyLong.class, Long.class); } @Test public void testWritableLongFeature() throws Exception { List x = Arrays.asList(new LongWritable(111L), new LongWritable(222L)); ObjectInspector featureOI = PrimitiveObjectInspectorFactory.writableLongObjectInspector; - testFeature(x, featureOI, LongWritable.class); + testFeature(x, featureOI, LongWritable.class, Long.class); } private void run(@Nonnull String options) throws Exception { diff --git a/core/src/test/java/hivemall/regression/GeneralRegressionUDTFTest.java b/core/src/test/java/hivemall/regression/GeneralRegressionUDTFTest.java index e39a22613..f352b8975 100644 --- a/core/src/test/java/hivemall/regression/GeneralRegressionUDTFTest.java +++ b/core/src/test/java/hivemall/regression/GeneralRegressionUDTFTest.java @@ -18,6 +18,10 @@ */ package hivemall.regression; +import static hivemall.utils.hadoop.HiveUtils.lazyInteger; +import static hivemall.utils.hadoop.HiveUtils.lazyLong; +import static hivemall.utils.hadoop.HiveUtils.lazyString; + import java.util.ArrayList; import java.util.Arrays; import java.util.List; @@ -27,6 +31,11 @@ import org.apache.hadoop.hive.ql.exec.UDFArgumentException; import org.apache.hadoop.hive.ql.metadata.HiveException; import org.apache.hadoop.hive.ql.udf.generic.Collector; +import org.apache.hadoop.hive.serde2.lazy.LazyInteger; +import org.apache.hadoop.hive.serde2.lazy.LazyLong; +import org.apache.hadoop.hive.serde2.lazy.LazyString; +import org.apache.hadoop.hive.serde2.lazy.objectinspector.primitive.LazyPrimitiveObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.lazy.objectinspector.primitive.LazyStringObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector; import 
org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; @@ -35,7 +44,6 @@ import org.apache.hadoop.io.IntWritable; import org.apache.hadoop.io.LongWritable; import org.apache.hadoop.io.Text; - import org.junit.Assert; import org.junit.Test; @@ -110,8 +118,8 @@ public void testNoOptions() throws Exception { Assert.assertEquals(y, predicted, 1E-5); } - private void testFeature(List x, ObjectInspector featureOI, Class featureClass) - throws Exception { + private void testFeature(@Nonnull List x, @Nonnull ObjectInspector featureOI, + @Nonnull Class featureClass, @Nonnull Class modelFeatureClass) throws Exception { float y = 0.f; GeneralRegressionUDTF udtf = new GeneralRegressionUDTF(); @@ -133,57 +141,83 @@ public void collect(Object input) throws HiveException { udtf.close(); - Class modelFeatureClass = modelFeatures.get(0).getClass(); + Assert.assertFalse(modelFeatures.isEmpty()); for (Object modelFeature : modelFeatures) { Assert.assertEquals("All model features must have same type", modelFeatureClass, modelFeature.getClass()); } + } - Assert.assertEquals( - "Model feature must correspond to UDTF output feature's object inspector", - featureClass, modelFeatureClass); + @Test + public void testLazyStringFeature() throws Exception { + LazyStringObjectInspector oi = LazyPrimitiveObjectInspectorFactory.getLazyStringObjectInspector( + false, (byte) 0); + List x = Arrays.asList(lazyString("テスト:-2", oi), lazyString("漢字:-333.0", oi), + lazyString("test:-1")); + testFeature(x, oi, LazyString.class, String.class); } @Test public void testStringFeature() throws Exception { List x = Arrays.asList("1:-2", "2:-1"); ObjectInspector featureOI = PrimitiveObjectInspectorFactory.javaStringObjectInspector; - testFeature(x, featureOI, String.class); + testFeature(x, featureOI, String.class, String.class); } + + @Test(expected = IllegalArgumentException.class) + public void testIllegalStringFeature() throws Exception { + List x = Arrays.asList("1:-2jjjj", "2:-1"); + ObjectInspector featureOI = PrimitiveObjectInspectorFactory.javaStringObjectInspector; + testFeature(x, featureOI, String.class, String.class); } @Test public void testTextFeature() throws Exception { List x = Arrays.asList(new Text("1:-2"), new Text("2:-1")); ObjectInspector featureOI = PrimitiveObjectInspectorFactory.writableStringObjectInspector; - testFeature(x, featureOI, Text.class); + testFeature(x, featureOI, Text.class, String.class); } @Test public void testIntegerFeature() throws Exception { List x = Arrays.asList(111, 222); ObjectInspector featureOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector; - testFeature(x, featureOI, Integer.class); + testFeature(x, featureOI, Integer.class, Integer.class); } + + @Test + public void testLazyIntegerFeature() throws Exception { + List x = Arrays.asList(lazyInteger(111), lazyInteger(222)); + ObjectInspector featureOI = LazyPrimitiveObjectInspectorFactory.LAZY_INT_OBJECT_INSPECTOR; + testFeature(x, featureOI, LazyInteger.class, Integer.class); } @Test public void testWritableIntFeature() throws Exception { List x = Arrays.asList(new IntWritable(111), new IntWritable(222)); ObjectInspector featureOI = PrimitiveObjectInspectorFactory.writableIntObjectInspector; - testFeature(x, featureOI, IntWritable.class); + testFeature(x, featureOI, IntWritable.class, Integer.class); } @Test public void testLongFeature() throws Exception { List x = Arrays.asList(111L, 222L); ObjectInspector featureOI =
PrimitiveObjectInspectorFactory.javaLongObjectInspector; - testFeature(x, featureOI, Long.class); + testFeature(x, featureOI, Long.class, Long.class); + } + + @Test + public void testLazyLongFeature() throws Exception { + List x = Arrays.asList(lazyLong(111), lazyLong(222)); + ObjectInspector featureOI = LazyPrimitiveObjectInspectorFactory.LAZY_LONG_OBJECT_INSPECTOR; + testFeature(x, featureOI, LazyLong.class, Long.class); } @Test public void testWritableLongFeature() throws Exception { List x = Arrays.asList(new LongWritable(111L), new LongWritable(222L)); ObjectInspector featureOI = PrimitiveObjectInspectorFactory.writableLongObjectInspector; - testFeature(x, featureOI, LongWritable.class); + testFeature(x, featureOI, LongWritable.class, Long.class); } private void run(@Nonnull String options) throws Exception { From 1997c39652a55a7fde7874ff137e2d311f3677f3 Mon Sep 17 00:00:00 2001 From: Makoto Yui Date: Fri, 14 Jul 2017 15:12:33 +0900 Subject: [PATCH 07/13] Fixed CI error caused by changes in Feature.parse() --- core/src/test/java/hivemall/model/FeatureValueTest.java | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/core/src/test/java/hivemall/model/FeatureValueTest.java b/core/src/test/java/hivemall/model/FeatureValueTest.java index 2b6c832e5..598e13a8d 100644 --- a/core/src/test/java/hivemall/model/FeatureValueTest.java +++ b/core/src/test/java/hivemall/model/FeatureValueTest.java @@ -50,12 +50,12 @@ public void testParseWithWeight() { } @Test(expected = IllegalArgumentException.class) - public void testParseExpectingIllegalArgumentException() { + public void testParseExpectingIllegalArgumentException1() { FeatureValue.parse("ad_url:"); } - @Test(expected = NumberFormatException.class) - public void testParseExpectingNumberFormatException() { + @Test(expected = IllegalArgumentException.class) + public void testParseExpectingIllegalArgumentException2() { FeatureValue.parse("ad_url:xxxxx"); } From 7df76bf68f450a95038cb689471fb7313ef09387 Mon Sep 17 00:00:00 2001 From: Makoto Yui Date: Fri, 14 Jul 2017 15:55:58 +0900 Subject: [PATCH 08/13] Added @SuppressWarnings("deprecation") to remove a compilation warning --- core/src/test/java/hivemall/regression/AdaGradUDTFTest.java | 1 + 1 file changed, 1 insertion(+) diff --git a/core/src/test/java/hivemall/regression/AdaGradUDTFTest.java b/core/src/test/java/hivemall/regression/AdaGradUDTFTest.java index e7a0a8902..fa7e28a7b 100644 --- a/core/src/test/java/hivemall/regression/AdaGradUDTFTest.java +++ b/core/src/test/java/hivemall/regression/AdaGradUDTFTest.java @@ -30,6 +30,7 @@ public class AdaGradUDTFTest { + @SuppressWarnings("deprecation") @Test public void testInitialize() throws UDFArgumentException { AdaGradUDTF udtf = new AdaGradUDTF(); From b28a0da8456c3e28d071dbb9b6139199db193b47 Mon Sep 17 00:00:00 2001 From: Makoto Yui Date: Fri, 14 Jul 2017 15:56:21 +0900 Subject: [PATCH 09/13] Added annotations.
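The hunks below only add @Nonnull to the return types of the as*OI() casting helpers in HiveUtils: each helper either narrows a generic ObjectInspector to a concrete primitive OI or throws UDFArgumentException, so a successful return is never null and callers (and static analysis) can drop redundant null checks. A minimal sketch of the annotated pattern, assuming javax.annotation.Nonnull as used elsewhere in Hivemall; asIntOI and its error message here are simplified stand-ins for the real helpers in HiveUtils:

    import javax.annotation.Nonnull;

    import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
    import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
    import org.apache.hadoop.hive.serde2.objectinspector.primitive.IntObjectInspector;

    public final class OICastSketch {

        // Mirrors the structure of HiveUtils.asIntOI(): either narrow the OI
        // or fail fast, so the @Nonnull return contract always holds.
        @Nonnull
        public static IntObjectInspector asIntOI(@Nonnull final ObjectInspector argOI)
                throws UDFArgumentException {
            if (!"int".equals(argOI.getTypeName())) {
                throw new UDFArgumentException("argument must be INT: " + argOI.getTypeName());
            }
            return (IntObjectInspector) argOI;
        }
    }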
--- .../src/main/java/hivemall/utils/hadoop/HiveUtils.java | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java b/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java index 85e4be20b..cb2b5e382 100644 --- a/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java +++ b/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java @@ -780,6 +780,7 @@ public static ConstantObjectInspector asConstantObjectInspector( return (ConstantObjectInspector) oi; } + @Nonnull public static PrimitiveObjectInspector asPrimitiveObjectInspector( @Nonnull final ObjectInspector oi) throws UDFArgumentException { if (oi.getCategory() != Category.PRIMITIVE) { @@ -789,6 +790,7 @@ public static PrimitiveObjectInspector asPrimitiveObjectInspector( return (PrimitiveObjectInspector) oi; } + @Nonnull public static StringObjectInspector asStringOI(@Nonnull final ObjectInspector argOI) throws UDFArgumentException { if (!STRING_TYPE_NAME.equals(argOI.getTypeName())) { @@ -797,6 +799,7 @@ public static StringObjectInspector asStringOI(@Nonnull final ObjectInspector ar return (StringObjectInspector) argOI; } + @Nonnull public static BinaryObjectInspector asBinaryOI(@Nonnull final ObjectInspector argOI) throws UDFArgumentException { if (!BINARY_TYPE_NAME.equals(argOI.getTypeName())) { @@ -805,6 +808,7 @@ public static BinaryObjectInspector asBinaryOI(@Nonnull final ObjectInspector ar return (BinaryObjectInspector) argOI; } + @Nonnull public static BooleanObjectInspector asBooleanOI(@Nonnull final ObjectInspector argOI) throws UDFArgumentException { if (!BOOLEAN_TYPE_NAME.equals(argOI.getTypeName())) { @@ -813,6 +817,7 @@ public static BooleanObjectInspector asBooleanOI(@Nonnull final ObjectInspector return (BooleanObjectInspector) argOI; } + @Nonnull public static IntObjectInspector asIntOI(@Nonnull final ObjectInspector argOI) throws UDFArgumentException { if (!INT_TYPE_NAME.equals(argOI.getTypeName())) { @@ -821,6 +826,7 @@ public static IntObjectInspector asIntOI(@Nonnull final ObjectInspector argOI) return (IntObjectInspector) argOI; } + @Nonnull public static LongObjectInspector asLongOI(@Nonnull final ObjectInspector argOI) throws UDFArgumentException { if (!BIGINT_TYPE_NAME.equals(argOI.getTypeName())) { @@ -829,6 +835,7 @@ public static LongObjectInspector asLongOI(@Nonnull final ObjectInspector argOI) return (LongObjectInspector) argOI; } + @Nonnull public static DoubleObjectInspector asDoubleOI(@Nonnull final ObjectInspector argOI) throws UDFArgumentException { if (!DOUBLE_TYPE_NAME.equals(argOI.getTypeName())) { @@ -837,6 +844,7 @@ public static DoubleObjectInspector asDoubleOI(@Nonnull final ObjectInspector ar return (DoubleObjectInspector) argOI; } + @Nonnull public static PrimitiveObjectInspector asIntCompatibleOI(@Nonnull final ObjectInspector argOI) throws UDFArgumentTypeException { if (argOI.getCategory() != Category.PRIMITIVE) { @@ -862,6 +870,7 @@ public static PrimitiveObjectInspector asIntCompatibleOI(@Nonnull final ObjectIn return oi; } + @Nonnull public static PrimitiveObjectInspector asLongCompatibleOI(@Nonnull final ObjectInspector argOI) throws UDFArgumentTypeException { if (argOI.getCategory() != Category.PRIMITIVE) { @@ -888,6 +897,7 @@ public static PrimitiveObjectInspector asLongCompatibleOI(@Nonnull final ObjectI return oi; } + @Nonnull public static PrimitiveObjectInspector asIntegerOI(@Nonnull final ObjectInspector argOI) throws UDFArgumentTypeException { if (argOI.getCategory() != Category.PRIMITIVE) { From 
ebbed2d14f5aed1b9d64fce6c7fdadd7aedf13fc Mon Sep 17 00:00:00 2001 From: Makoto Yui Date: Fri, 14 Jul 2017 16:06:21 +0900 Subject: [PATCH 10/13] Removed unused import --- core/src/main/java/hivemall/smile/regression/RegressionTree.java | 1 - 1 file changed, 1 deletion(-) diff --git a/core/src/main/java/hivemall/smile/regression/RegressionTree.java b/core/src/main/java/hivemall/smile/regression/RegressionTree.java index 5ec27df40..38b7b8347 100755 --- a/core/src/main/java/hivemall/smile/regression/RegressionTree.java +++ b/core/src/main/java/hivemall/smile/regression/RegressionTree.java @@ -34,7 +34,6 @@ package hivemall.smile.regression; import static hivemall.smile.utils.SmileExtUtils.resolveFeatureName; -import static hivemall.smile.utils.SmileExtUtils.resolveName; import hivemall.annotations.VisibleForTesting; import hivemall.math.matrix.Matrix; import hivemall.math.matrix.ints.ColumnMajorIntMatrix; From b0f494e4fa3f38cadfc9ac9e6bb59b333ab85665 Mon Sep 17 00:00:00 2001 From: Makoto Yui Date: Fri, 14 Jul 2017 16:07:52 +0900 Subject: [PATCH 11/13] [Disruptive change] Removed a deprecated functionality to load a prediction model from HDFS --- .../java/hivemall/GeneralLearnerBaseUDTF.java | 8 +- .../main/java/hivemall/LearnerBaseUDTF.java | 143 ------------------ .../BinaryOnlineClassifierUDTF.java | 9 +- .../MulticlassOnlineClassifierUDTF.java | 3 - .../regression/RegressionBaseUDTF.java | 3 - 5 files changed, 7 insertions(+), 159 deletions(-) diff --git a/core/src/main/java/hivemall/GeneralLearnerBaseUDTF.java b/core/src/main/java/hivemall/GeneralLearnerBaseUDTF.java index 6bf38ce11..e4adb0558 100644 --- a/core/src/main/java/hivemall/GeneralLearnerBaseUDTF.java +++ b/core/src/main/java/hivemall/GeneralLearnerBaseUDTF.java @@ -155,9 +155,6 @@ public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgu PrimitiveObjectInspector featureOutputOI = dense_model ? 
PrimitiveObjectInspectorFactory.javaIntObjectInspector : featureInputOI; this.model = createModel(); - if (preloadedModelFile != null) { - loadPredictionModel(model, preloadedModelFile, featureOutputOI); - } try { this.optimizer = createOptimizer(optimizerOptions); @@ -243,7 +240,7 @@ protected PrimitiveObjectInspector processFeaturesOI(@Nonnull ObjectInspector ar } @Nonnull - protected StructObjectInspector getReturnOI(@Nonnull ObjectInspector featureOutputOI) { + protected StructObjectInspector getReturnOI(@Nonnull PrimitiveObjectInspector featureOutputOI) { ArrayList fieldNames = new ArrayList(); ArrayList fieldOIs = new ArrayList(); @@ -364,7 +361,8 @@ private static FeatureValue readFeatureValue(@Nonnull final ByteBuffer buf, feature = Long.valueOf(featureStr); break; default: - throw new IllegalStateException("Unexpected feature type " + featureType + " for feature: " + featureStr); + throw new IllegalStateException("Unexpected feature type " + featureType + + " for feature: " + featureStr); } double value = buf.getDouble(); return new FeatureValue(feature, value); diff --git a/core/src/main/java/hivemall/LearnerBaseUDTF.java b/core/src/main/java/hivemall/LearnerBaseUDTF.java index fdb22f881..18ef23b9c 100644 --- a/core/src/main/java/hivemall/LearnerBaseUDTF.java +++ b/core/src/main/java/hivemall/LearnerBaseUDTF.java @@ -18,7 +18,6 @@ */ package hivemall; -import static org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory.writableFloatObjectInspector; import hivemall.mix.MixMessage.MixEventName; import hivemall.mix.client.MixClient; import hivemall.model.DenseModel; @@ -29,22 +28,14 @@ import hivemall.model.SpaceEfficientDenseModel; import hivemall.model.SparseModel; import hivemall.model.SynchronizedModelWrapper; -import hivemall.model.WeightValue; -import hivemall.model.WeightValue.WeightValueWithCovar; import hivemall.optimizer.DenseOptimizerFactory; import hivemall.optimizer.Optimizer; import hivemall.optimizer.SparseOptimizerFactory; -import hivemall.utils.datetime.StopWatch; -import hivemall.utils.hadoop.HadoopUtils; import hivemall.utils.hadoop.HiveUtils; import hivemall.utils.io.IOUtils; import hivemall.utils.lang.Preconditions; import hivemall.utils.lang.Primitives; -import java.io.BufferedReader; -import java.io.File; -import java.io.IOException; -import java.util.List; import java.util.Map; import javax.annotation.CheckForNull; @@ -57,21 +48,12 @@ import org.apache.commons.logging.LogFactory; import org.apache.hadoop.hive.ql.exec.UDFArgumentException; import org.apache.hadoop.hive.ql.metadata.HiveException; -import org.apache.hadoop.hive.serde2.SerDeException; -import org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; -import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; -import org.apache.hadoop.hive.serde2.objectinspector.StructField; -import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; -import org.apache.hadoop.hive.serde2.objectinspector.primitive.FloatObjectInspector; -import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableFloatObjectInspector; -import org.apache.hadoop.io.Text; public abstract class LearnerBaseUDTF extends UDTFWithOptions { private static final Log logger = LogFactory.getLog(LearnerBaseUDTF.class); protected final boolean enableNewModel; - protected String preloadedModelFile; protected boolean dense_model; protected int model_dims; protected boolean disable_halffloat; @@ -97,7 
+79,6 @@ protected boolean useCovariance() { @Override protected Options getOptions() { Options opts = new Options(); - opts.addOption("loadmodel", true, "Model file name in the distributed cache"); opts.addOption("dense", "densemodel", false, "Use dense model or not"); opts.addOption("dims", "feature_dimensions", true, "The dimension of model [default: 16777216 (2^24)]"); @@ -119,7 +100,6 @@ protected Options getOptions() { @Override protected CommandLine processOptions(@Nonnull ObjectInspector[] argOIs) throws UDFArgumentException { - String modelfile = null; boolean denseModel = false; int modelDims = -1; boolean disableHalfFloat = false; @@ -135,8 +115,6 @@ protected CommandLine processOptions(@Nonnull ObjectInspector[] argOIs) String rawArgs = HiveUtils.getConstString(argOIs[2]); cl = parseOptions(rawArgs); - modelfile = cl.getOptionValue("loadmodel"); - denseModel = cl.hasOption("dense"); if (denseModel) { modelDims = Primitives.parseInt(cl.getOptionValue("dims"), 16777216); @@ -160,7 +138,6 @@ protected CommandLine processOptions(@Nonnull ObjectInspector[] argOIs) ssl = cl.hasOption("ssl"); } - this.preloadedModelFile = modelfile; this.dense_model = denseModel; this.model_dims = modelDims; this.disable_halffloat = disableHalfFloat; @@ -272,126 +249,6 @@ protected int getInitialModelSize() { return 16384; } - protected void loadPredictionModel(PredictionModel model, String filename, - PrimitiveObjectInspector keyOI) { - final StopWatch elapsed = new StopWatch(); - final long lines; - try { - if (useCovariance()) { - lines = loadPredictionModel(model, new File(filename), keyOI, - writableFloatObjectInspector, writableFloatObjectInspector); - } else { - lines = loadPredictionModel(model, new File(filename), keyOI, - writableFloatObjectInspector); - } - } catch (IOException e) { - throw new RuntimeException("Failed to load a model: " + filename, e); - } catch (SerDeException e) { - throw new RuntimeException("Failed to load a model: " + filename, e); - } - if (model.size() > 0) { - logger.info("Loaded " + model.size() + " features from distributed cache '" + filename - + "' (" + lines + " lines) in " + elapsed); - } - } - - private static long loadPredictionModel(PredictionModel model, File file, - PrimitiveObjectInspector keyOI, WritableFloatObjectInspector valueOI) - throws IOException, SerDeException { - long count = 0L; - if (!file.exists()) { - return count; - } - if (!file.getName().endsWith(".crc")) { - if (file.isDirectory()) { - for (File f : file.listFiles()) { - count += loadPredictionModel(model, f, keyOI, valueOI); - } - } else { - LazySimpleSerDe serde = HiveUtils.getKeyValueLineSerde(keyOI, valueOI); - StructObjectInspector lineOI = (StructObjectInspector) serde.getObjectInspector(); - StructField keyRef = lineOI.getStructFieldRef("key"); - StructField valueRef = lineOI.getStructFieldRef("value"); - PrimitiveObjectInspector keyRefOI = (PrimitiveObjectInspector) keyRef.getFieldObjectInspector(); - FloatObjectInspector varRefOI = (FloatObjectInspector) valueRef.getFieldObjectInspector(); - - BufferedReader reader = null; - try { - reader = HadoopUtils.getBufferedReader(file); - String line; - while ((line = reader.readLine()) != null) { - count++; - Text lineText = new Text(line); - Object lineObj = serde.deserialize(lineText); - List fields = lineOI.getStructFieldsDataAsList(lineObj); - Object f0 = fields.get(0); - Object f1 = fields.get(1); - if (f0 == null || f1 == null) { - continue; // avoid the case that key or value is null - } - Object k = 
keyRefOI.getPrimitiveWritableObject(keyRefOI.copyObject(f0)); - float v = varRefOI.get(f1); - model.set(k, new WeightValue(v, false)); - } - } finally { - IOUtils.closeQuietly(reader); - } - } - } - return count; - } - - private static long loadPredictionModel(PredictionModel model, File file, - PrimitiveObjectInspector featureOI, WritableFloatObjectInspector weightOI, - WritableFloatObjectInspector covarOI) throws IOException, SerDeException { - long count = 0L; - if (!file.exists()) { - return count; - } - if (!file.getName().endsWith(".crc")) { - if (file.isDirectory()) { - for (File f : file.listFiles()) { - count += loadPredictionModel(model, f, featureOI, weightOI, covarOI); - } - } else { - LazySimpleSerDe serde = HiveUtils.getLineSerde(featureOI, weightOI, covarOI); - StructObjectInspector lineOI = (StructObjectInspector) serde.getObjectInspector(); - StructField c1ref = lineOI.getStructFieldRef("c1"); - StructField c2ref = lineOI.getStructFieldRef("c2"); - StructField c3ref = lineOI.getStructFieldRef("c3"); - PrimitiveObjectInspector c1oi = (PrimitiveObjectInspector) c1ref.getFieldObjectInspector(); - FloatObjectInspector c2oi = (FloatObjectInspector) c2ref.getFieldObjectInspector(); - FloatObjectInspector c3oi = (FloatObjectInspector) c3ref.getFieldObjectInspector(); - - BufferedReader reader = null; - try { - reader = HadoopUtils.getBufferedReader(file); - String line; - while ((line = reader.readLine()) != null) { - count++; - Text lineText = new Text(line); - Object lineObj = serde.deserialize(lineText); - List fields = lineOI.getStructFieldsDataAsList(lineObj); - Object f0 = fields.get(0); - Object f1 = fields.get(1); - Object f2 = fields.get(2); - if (f0 == null || f1 == null) { - continue; // avoid unexpected case - } - Object k = c1oi.getPrimitiveWritableObject(c1oi.copyObject(f0)); - float v = c2oi.get(f1); - float cov = (f2 == null) ? WeightValueWithCovar.DEFAULT_COVAR - : c3oi.get(f2); - model.set(k, new WeightValueWithCovar(v, cov, false)); - } - } finally { - IOUtils.closeQuietly(reader); - } - } - } - return count; - } - @Override public void close() throws HiveException { if (mixClient != null) { diff --git a/core/src/main/java/hivemall/classifier/BinaryOnlineClassifierUDTF.java b/core/src/main/java/hivemall/classifier/BinaryOnlineClassifierUDTF.java index 2dcf521cc..b1acd73e4 100644 --- a/core/src/main/java/hivemall/classifier/BinaryOnlineClassifierUDTF.java +++ b/core/src/main/java/hivemall/classifier/BinaryOnlineClassifierUDTF.java @@ -88,16 +88,14 @@ public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgu PrimitiveObjectInspector featureOutputOI = dense_model ? 
PrimitiveObjectInspectorFactory.javaIntObjectInspector : featureInputOI; this.model = createModel(); - if (preloadedModelFile != null) { - loadPredictionModel(model, preloadedModelFile, featureOutputOI); - } this.count = 0; this.sampled = 0; return getReturnOI(featureOutputOI); } - protected PrimitiveObjectInspector processFeaturesOI(ObjectInspector arg) + @Nonnull + protected PrimitiveObjectInspector processFeaturesOI(@Nonnull ObjectInspector arg) throws UDFArgumentException { this.featureListOI = (ListObjectInspector) arg; ObjectInspector featureRawOI = featureListOI.getListElementObjectInspector(); @@ -106,7 +104,8 @@ protected PrimitiveObjectInspector processFeaturesOI(ObjectInspector arg) return HiveUtils.asPrimitiveObjectInspector(featureRawOI); } - protected StructObjectInspector getReturnOI(ObjectInspector featureRawOI) { + @Nonnull + protected StructObjectInspector getReturnOI(@Nonnull ObjectInspector featureRawOI) { ArrayList fieldNames = new ArrayList(); ArrayList fieldOIs = new ArrayList(); diff --git a/core/src/main/java/hivemall/classifier/multiclass/MulticlassOnlineClassifierUDTF.java b/core/src/main/java/hivemall/classifier/multiclass/MulticlassOnlineClassifierUDTF.java index af8545c32..8355ad32b 100644 --- a/core/src/main/java/hivemall/classifier/multiclass/MulticlassOnlineClassifierUDTF.java +++ b/core/src/main/java/hivemall/classifier/multiclass/MulticlassOnlineClassifierUDTF.java @@ -106,9 +106,6 @@ public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgu PrimitiveObjectInspector featureOutputOI = dense_model ? PrimitiveObjectInspectorFactory.javaIntObjectInspector : featureInputOI; this.label2model = new HashMap(64); - if (preloadedModelFile != null) { - loadPredictionModel(label2model, preloadedModelFile, labelInputOI, featureOutputOI); - } this.count = 0; return getReturnOI(labelInputOI, featureOutputOI); diff --git a/core/src/main/java/hivemall/regression/RegressionBaseUDTF.java b/core/src/main/java/hivemall/regression/RegressionBaseUDTF.java index f8fae89e8..e802f2d34 100644 --- a/core/src/main/java/hivemall/regression/RegressionBaseUDTF.java +++ b/core/src/main/java/hivemall/regression/RegressionBaseUDTF.java @@ -92,9 +92,6 @@ public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgu PrimitiveObjectInspector featureOutputOI = dense_model ? 
PrimitiveObjectInspectorFactory.javaIntObjectInspector : featureInputOI; this.model = createModel(); - if (preloadedModelFile != null) { - loadPredictionModel(model, preloadedModelFile, featureOutputOI); - } this.count = 0; this.sampled = 0; From 32684b1b9b763aaf3f665f1bf07deceedd4fc1e0 Mon Sep 17 00:00:00 2001 From: Makoto Yui Date: Fri, 14 Jul 2017 17:59:26 +0900 Subject: [PATCH 12/13] Fixed output OI resolution scheme --- .../java/hivemall/GeneralLearnerBaseUDTF.java | 74 +++++++++++++------ .../main/java/hivemall/LearnerBaseUDTF.java | 20 +++++ .../BinaryOnlineClassifierUDTF.java | 11 +-- .../MulticlassOnlineClassifierUDTF.java | 17 ++--- .../regression/RegressionBaseUDTF.java | 15 ++-- 5 files changed, 90 insertions(+), 47 deletions(-) diff --git a/core/src/main/java/hivemall/GeneralLearnerBaseUDTF.java b/core/src/main/java/hivemall/GeneralLearnerBaseUDTF.java index e4adb0558..669b2130e 100644 --- a/core/src/main/java/hivemall/GeneralLearnerBaseUDTF.java +++ b/core/src/main/java/hivemall/GeneralLearnerBaseUDTF.java @@ -77,12 +77,7 @@ public abstract class GeneralLearnerBaseUDTF extends LearnerBaseUDTF { private static final Log logger = LogFactory.getLog(GeneralLearnerBaseUDTF.class); - public enum FeatureType { - STRING, INT, LONG - } - private ListObjectInspector featureListOI; - private PrimitiveObjectInspector featureInputOI; private PrimitiveObjectInspector targetOI; private FeatureType featureType; @@ -147,13 +142,12 @@ public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgu throw new UDFArgumentException( "_FUNC_ takes 2 arguments: List features, float target [, constant string options]"); } - this.featureInputOI = processFeaturesOI(argOIs[0]); + this.featureListOI = HiveUtils.asListOI(argOIs[0]); + this.featureType = getFeatureType(featureListOI); this.targetOI = HiveUtils.asDoubleCompatibleOI(argOIs[1]); processOptions(argOIs); - PrimitiveObjectInspector featureOutputOI = dense_model ? 
PrimitiveObjectInspectorFactory.javaIntObjectInspector - : featureInputOI; this.model = createModel(); try { @@ -165,7 +159,7 @@ public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgu this.count = 0L; this.sampled = 0; - return getReturnOI(featureOutputOI); + return getReturnOI(getFeatureOutputOI(featureType)); } @Override @@ -220,33 +214,67 @@ protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumen return cl; } + public enum FeatureType { + STRING, INT, LONG + } + @Nonnull - protected PrimitiveObjectInspector processFeaturesOI(@Nonnull ObjectInspector arg) + private static FeatureType getFeatureType(@Nonnull ListObjectInspector featureListOI) throws UDFArgumentException { - this.featureListOI = (ListObjectInspector) arg; - ObjectInspector featureRawOI = featureListOI.getListElementObjectInspector(); - if (featureRawOI instanceof StringObjectInspector) { - this.featureType = FeatureType.STRING; - } else if (featureRawOI instanceof IntObjectInspector) { - this.featureType = FeatureType.INT; - } else if (featureRawOI instanceof LongObjectInspector) { - this.featureType = FeatureType.LONG; + final ObjectInspector featureOI = featureListOI.getListElementObjectInspector(); + if (featureOI instanceof StringObjectInspector) { + return FeatureType.STRING; + } else if (featureOI instanceof IntObjectInspector) { + return FeatureType.INT; + } else if (featureOI instanceof LongObjectInspector) { + return FeatureType.LONG; } else { throw new UDFArgumentException("Feature object inspector must be one of " + "[StringObjectInspector, IntObjectInspector, LongObjectInspector]: " - + featureRawOI.toString()); + + featureOI.toString()); + } + } + + @Nonnull + protected final ObjectInspector getFeatureOutputOI(@Nonnull final FeatureType featureType) + throws UDFArgumentException { + final PrimitiveObjectInspector outputOI; + if (dense_model) { + switch (featureType) { + case INT: + case LONG: + outputOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector; // see DenseModel (long is also parsed as int) + break; + default: + throw new UDFArgumentException( + "Only INT or BIGINT is allowed for the element of feature vector when -densemodel option is specified: " + + featureType); + } + } else { + switch (featureType) { + case STRING: + outputOI = PrimitiveObjectInspectorFactory.javaStringObjectInspector; + break; + case INT: + outputOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector; + break; + case LONG: + outputOI = PrimitiveObjectInspectorFactory.javaLongObjectInspector; + break; + default: + throw new IllegalStateException("Unexpected feature type: " + featureType); + } } - return HiveUtils.asPrimitiveObjectInspector(featureRawOI); + return outputOI; } @Nonnull - protected StructObjectInspector getReturnOI(@Nonnull PrimitiveObjectInspector featureOutputOI) { + protected StructObjectInspector getReturnOI(@Nonnull ObjectInspector featureOutputOI) { ArrayList fieldNames = new ArrayList(); ArrayList fieldOIs = new ArrayList(); fieldNames.add("feature"); - ObjectInspector featureOI = ObjectInspectorUtils.getStandardObjectInspector(featureOutputOI); - fieldOIs.add(featureOI); + fieldOIs.add(featureOutputOI); fieldNames.add("weight"); fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector); if (useCovariance()) { diff --git a/core/src/main/java/hivemall/LearnerBaseUDTF.java b/core/src/main/java/hivemall/LearnerBaseUDTF.java index 18ef23b9c..630e929b9 100644 --- a/core/src/main/java/hivemall/LearnerBaseUDTF.java +++ 
b/core/src/main/java/hivemall/LearnerBaseUDTF.java @@ -18,6 +18,8 @@ */ package hivemall; +import static hivemall.HivemallConstants.BIGINT_TYPE_NAME; +import static hivemall.HivemallConstants.INT_TYPE_NAME; import hivemall.mix.MixMessage.MixEventName; import hivemall.mix.client.MixClient; import hivemall.model.DenseModel; @@ -49,6 +51,9 @@ import org.apache.hadoop.hive.ql.exec.UDFArgumentException; import org.apache.hadoop.hive.ql.metadata.HiveException; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; public abstract class LearnerBaseUDTF extends UDTFWithOptions { private static final Log logger = LogFactory.getLog(LearnerBaseUDTF.class); @@ -249,6 +254,21 @@ protected int getInitialModelSize() { return 16384; } + @Nonnull + protected ObjectInspector getFeatureOutputOI(@Nonnull PrimitiveObjectInspector featureInputOI) + throws UDFArgumentException { + if (dense_model) { + final String typeName = featureInputOI.getTypeName(); + if (INT_TYPE_NAME.equals(typeName) || BIGINT_TYPE_NAME.equals(typeName)) { + return PrimitiveObjectInspectorFactory.javaIntObjectInspector; // see DenseModel + } + throw new UDFArgumentException( + "Only INT or BIGINT is allowed for the element of feature vector when -densemodel option is specified: " + + typeName); + } + return ObjectInspectorUtils.getStandardObjectInspector(featureInputOI); + } + @Override public void close() throws HiveException { if (mixClient != null) { diff --git a/core/src/main/java/hivemall/classifier/BinaryOnlineClassifierUDTF.java b/core/src/main/java/hivemall/classifier/BinaryOnlineClassifierUDTF.java index b1acd73e4..2f4db3ad3 100644 --- a/core/src/main/java/hivemall/classifier/BinaryOnlineClassifierUDTF.java +++ b/core/src/main/java/hivemall/classifier/BinaryOnlineClassifierUDTF.java @@ -85,13 +85,11 @@ public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgu processOptions(argOIs); - PrimitiveObjectInspector featureOutputOI = dense_model ? 
PrimitiveObjectInspectorFactory.javaIntObjectInspector - : featureInputOI; this.model = createModel(); - this.count = 0; this.sampled = 0; - return getReturnOI(featureOutputOI); + + return getReturnOI(getFeatureOutputOI(featureInputOI)); } @Nonnull @@ -105,13 +103,12 @@ protected PrimitiveObjectInspector processFeaturesOI(@Nonnull ObjectInspector ar } @Nonnull - protected StructObjectInspector getReturnOI(@Nonnull ObjectInspector featureRawOI) { + protected StructObjectInspector getReturnOI(@Nonnull ObjectInspector featureOutputOI) { ArrayList fieldNames = new ArrayList(); ArrayList fieldOIs = new ArrayList(); fieldNames.add("feature"); - ObjectInspector featureOI = ObjectInspectorUtils.getStandardObjectInspector(featureRawOI); - fieldOIs.add(featureOI); + fieldOIs.add(featureOutputOI); fieldNames.add("weight"); fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector); if (useCovariance()) { diff --git a/core/src/main/java/hivemall/classifier/multiclass/MulticlassOnlineClassifierUDTF.java b/core/src/main/java/hivemall/classifier/multiclass/MulticlassOnlineClassifierUDTF.java index 8355ad32b..08a040b57 100644 --- a/core/src/main/java/hivemall/classifier/multiclass/MulticlassOnlineClassifierUDTF.java +++ b/core/src/main/java/hivemall/classifier/multiclass/MulticlassOnlineClassifierUDTF.java @@ -103,12 +103,10 @@ public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgu processOptions(argOIs); - PrimitiveObjectInspector featureOutputOI = dense_model ? PrimitiveObjectInspectorFactory.javaIntObjectInspector - : featureInputOI; this.label2model = new HashMap(64); - this.count = 0; - return getReturnOI(labelInputOI, featureOutputOI); + + return getReturnOI(labelInputOI, getFeatureOutputOI(featureInputOI)); } @Override @@ -116,7 +114,8 @@ protected int getInitialModelSize() { return 8192; } - protected PrimitiveObjectInspector processFeaturesOI(ObjectInspector arg) + @Nonnull + protected PrimitiveObjectInspector processFeaturesOI(@Nonnull ObjectInspector arg) throws UDFArgumentException { this.featureListOI = (ListObjectInspector) arg; ObjectInspector featureRawOI = featureListOI.getListElementObjectInspector(); @@ -130,8 +129,9 @@ protected PrimitiveObjectInspector processFeaturesOI(ObjectInspector arg) return HiveUtils.asPrimitiveObjectInspector(featureRawOI); } - protected StructObjectInspector getReturnOI(ObjectInspector labelRawOI, - ObjectInspector featureRawOI) { + @Nonnull + protected StructObjectInspector getReturnOI(@Nonnull ObjectInspector labelRawOI, + @Nonnull ObjectInspector featureOutputOI) { ArrayList fieldNames = new ArrayList(); ArrayList fieldOIs = new ArrayList(); @@ -139,8 +139,7 @@ protected StructObjectInspector getReturnOI(ObjectInspector labelRawOI, ObjectInspector labelOI = ObjectInspectorUtils.getStandardObjectInspector(labelRawOI); fieldOIs.add(labelOI); fieldNames.add("feature"); - ObjectInspector featureOI = ObjectInspectorUtils.getStandardObjectInspector(featureRawOI); - fieldOIs.add(featureOI); + fieldOIs.add(featureOutputOI); fieldNames.add("weight"); fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector); if (useCovariance()) { diff --git a/core/src/main/java/hivemall/regression/RegressionBaseUDTF.java b/core/src/main/java/hivemall/regression/RegressionBaseUDTF.java index e802f2d34..33196ab69 100644 --- a/core/src/main/java/hivemall/regression/RegressionBaseUDTF.java +++ b/core/src/main/java/hivemall/regression/RegressionBaseUDTF.java @@ -89,16 +89,15 @@ public StructObjectInspector 
initialize(ObjectInspector[] argOIs) throws UDFArgu processOptions(argOIs); - PrimitiveObjectInspector featureOutputOI = dense_model ? PrimitiveObjectInspectorFactory.javaIntObjectInspector - : featureInputOI; this.model = createModel(); - this.count = 0; this.sampled = 0; - return getReturnOI(featureOutputOI); + + return getReturnOI(getFeatureOutputOI(featureInputOI)); } - protected PrimitiveObjectInspector processFeaturesOI(ObjectInspector arg) + @Nonnull + protected PrimitiveObjectInspector processFeaturesOI(@Nonnull ObjectInspector arg) throws UDFArgumentException { this.featureListOI = (ListObjectInspector) arg; ObjectInspector featureRawOI = featureListOI.getListElementObjectInspector(); @@ -107,13 +106,13 @@ protected PrimitiveObjectInspector processFeaturesOI(ObjectInspector arg) return HiveUtils.asPrimitiveObjectInspector(featureRawOI); } - protected StructObjectInspector getReturnOI(ObjectInspector featureOutputOI) { + @Nonnull + protected StructObjectInspector getReturnOI(@Nonnull ObjectInspector featureOutputOI) { ArrayList fieldNames = new ArrayList(); ArrayList fieldOIs = new ArrayList(); fieldNames.add("feature"); - ObjectInspector featureOI = ObjectInspectorUtils.getStandardObjectInspector(featureOutputOI); - fieldOIs.add(featureOI); + fieldOIs.add(featureOutputOI); fieldNames.add("weight"); fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector); if (useCovariance()) { From 902ca08efda5310ffd14be8cbffc3dccb7e6f4b6 Mon Sep 17 00:00:00 2001 From: Makoto Yui Date: Fri, 14 Jul 2017 19:19:50 +0900 Subject: [PATCH 13/13] Fixed to accept String element in feature vector even when -densemodel is specified --- .../main/java/hivemall/GeneralLearnerBaseUDTF.java | 12 ++---------- core/src/main/java/hivemall/LearnerBaseUDTF.java | 11 ++--------- 2 files changed, 4 insertions(+), 19 deletions(-) diff --git a/core/src/main/java/hivemall/GeneralLearnerBaseUDTF.java b/core/src/main/java/hivemall/GeneralLearnerBaseUDTF.java index 669b2130e..f1bc0458f 100644 --- a/core/src/main/java/hivemall/GeneralLearnerBaseUDTF.java +++ b/core/src/main/java/hivemall/GeneralLearnerBaseUDTF.java @@ -240,16 +240,8 @@ protected final ObjectInspector getFeatureOutputOI(@Nonnull final FeatureType fe throws UDFArgumentException { final PrimitiveObjectInspector outputOI; if (dense_model) { - switch (featureType) { - case INT: - case LONG: - outputOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector; // see DenseModel (long is also parsed as int) - break; - default: - throw new UDFArgumentException( - "Only INT or BIGINT is allowed for the element of feature vector when -densemodel option is specified: " - + featureType); - } + // TODO validation + outputOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector; // see DenseModel (long/string is also parsed as int) } else { switch (featureType) { case STRING: diff --git a/core/src/main/java/hivemall/LearnerBaseUDTF.java b/core/src/main/java/hivemall/LearnerBaseUDTF.java index 630e929b9..b9ec668cb 100644 --- a/core/src/main/java/hivemall/LearnerBaseUDTF.java +++ b/core/src/main/java/hivemall/LearnerBaseUDTF.java @@ -18,8 +18,6 @@ */ package hivemall; -import static hivemall.HivemallConstants.BIGINT_TYPE_NAME; -import static hivemall.HivemallConstants.INT_TYPE_NAME; import hivemall.mix.MixMessage.MixEventName; import hivemall.mix.client.MixClient; import hivemall.model.DenseModel; @@ -258,13 +256,8 @@ protected int getInitialModelSize() { protected ObjectInspector getFeatureOutputOI(@Nonnull PrimitiveObjectInspector featureInputOI) 
throws UDFArgumentException { if (dense_model) { - final String typeName = featureInputOI.getTypeName(); - if (INT_TYPE_NAME.equals(typeName) || BIGINT_TYPE_NAME.equals(typeName)) { - return PrimitiveObjectInspectorFactory.javaIntObjectInspector; // see DenseModel - } - throw new UDFArgumentException( - "Only INT or BIGINT is allowed for the element of feature vector when -densemodel option is specified: " - + typeName); + // TODO validation + return PrimitiveObjectInspectorFactory.javaIntObjectInspector; // see DenseModel } return ObjectInspectorUtils.getStandardObjectInspector(featureInputOI); }
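Note on the sample serialization that the patches above keep reshaping (writeFeatureValue()/readFeatureValue() and the FeatureType enum): for the '-iter' support, each training sample is buffered as a length-prefixed feature string plus a double value (and a float target per row), then replayed from the buffer on later iterations, with the feature re-materialized according to FeatureType. A minimal, self-contained round-trip sketch follows; NIOUtils.putString()/getString() are approximated with length-prefixed UTF-8 writes so it compiles standalone (the actual NIOUtils encoding may differ, e.g. 2-byte chars as the SizeOf.CHAR estimate suggests), and the Object[] pair stands in for FeatureValue:

    import java.nio.ByteBuffer;
    import java.nio.charset.StandardCharsets;

    public final class FeatureRoundTripSketch {

        enum FeatureType { STRING, INT, LONG }

        // Analogous to writeFeatureValue(): feature as a length-prefixed
        // string, then its double value.
        static void write(final ByteBuffer buf, final String feature, final double value) {
            final byte[] b = feature.getBytes(StandardCharsets.UTF_8);
            buf.putInt(b.length); // NIOUtils.putString() also writes the length first
            buf.put(b);
            buf.putDouble(value);
        }

        // Analogous to readFeatureValue(): re-materialize the feature
        // according to the FeatureType resolved at initialize() time.
        static Object[] read(final ByteBuffer buf, final FeatureType featureType) {
            final byte[] b = new byte[buf.getInt()];
            buf.get(b);
            final String featureStr = new String(b, StandardCharsets.UTF_8);
            final Object feature;
            switch (featureType) {
                case STRING:
                    feature = featureStr;
                    break;
                case INT:
                    feature = Integer.valueOf(featureStr);
                    break;
                case LONG:
                    feature = Long.valueOf(featureStr);
                    break;
                default:
                    throw new IllegalStateException("Unexpected feature type: " + featureType);
            }
            final double value = buf.getDouble();
            return new Object[] {feature, value}; // stands in for new FeatureValue(feature, value)
        }

        public static void main(String[] args) {
            final ByteBuffer buf = ByteBuffer.allocate(64);
            write(buf, "42", -2.5d);
            buf.flip();
            final Object[] fv = read(buf, FeatureType.INT);
            System.out.println(fv[0] + ":" + fv[1]); // => 42:-2.5
        }
    }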