diff --git a/core/src/main/java/hivemall/smile/classification/DecisionTree.java b/core/src/main/java/hivemall/smile/classification/DecisionTree.java index 2b723ed60..f2ff560b7 100644 --- a/core/src/main/java/hivemall/smile/classification/DecisionTree.java +++ b/core/src/main/java/hivemall/smile/classification/DecisionTree.java @@ -49,6 +49,7 @@ import hivemall.smile.utils.SmileExtUtils; import hivemall.utils.collections.lists.IntArrayList; import hivemall.utils.lang.ObjectUtils; +import hivemall.utils.lang.StringUtils; import hivemall.utils.lang.mutable.MutableInt; import hivemall.utils.sampling.IntReservoirSampler; @@ -56,7 +57,9 @@ import java.io.IOException; import java.io.ObjectInput; import java.io.ObjectOutput; +import java.util.ArrayList; import java.util.Arrays; +import java.util.List; import java.util.PriorityQueue; import javax.annotation.Nonnull; @@ -426,6 +429,58 @@ public void exportGraphviz(@Nonnull final StringBuilder builder, } } + @Deprecated + public int opCodegen(@Nonnull final List scripts, int depth) { + int selfDepth = 0; + final StringBuilder buf = new StringBuilder(); + if (trueChild == null && falseChild == null) { + buf.append("push ").append(output); + scripts.add(buf.toString()); + buf.setLength(0); + buf.append("goto last"); + scripts.add(buf.toString()); + selfDepth += 2; + } else { + if (splitFeatureType == AttributeType.NOMINAL) { + buf.append("push ").append("x[").append(splitFeature).append("]"); + scripts.add(buf.toString()); + buf.setLength(0); + buf.append("push ").append(splitValue); + scripts.add(buf.toString()); + buf.setLength(0); + buf.append("ifeq "); + scripts.add(buf.toString()); + depth += 3; + selfDepth += 3; + int trueDepth = trueChild.opCodegen(scripts, depth); + selfDepth += trueDepth; + scripts.set(depth - 1, "ifeq " + String.valueOf(depth + trueDepth)); + int falseDepth = falseChild.opCodegen(scripts, depth + trueDepth); + selfDepth += falseDepth; + } else if (splitFeatureType == AttributeType.NUMERIC) { + buf.append("push ").append("x[").append(splitFeature).append("]"); + scripts.add(buf.toString()); + buf.setLength(0); + buf.append("push ").append(splitValue); + scripts.add(buf.toString()); + buf.setLength(0); + buf.append("ifle "); + scripts.add(buf.toString()); + depth += 3; + selfDepth += 3; + int trueDepth = trueChild.opCodegen(scripts, depth); + selfDepth += trueDepth; + scripts.set(depth - 1, "ifle " + String.valueOf(depth + trueDepth)); + int falseDepth = falseChild.opCodegen(scripts, depth + trueDepth); + selfDepth += falseDepth; + } else { + throw new IllegalStateException("Unsupported attribute type: " + + splitFeatureType); + } + } + return selfDepth; + } + @Override public void writeExternal(ObjectOutput out) throws IOException { out.writeInt(splitFeature); @@ -1069,6 +1124,7 @@ public int predict(Vector x, double[] posteriori) { throw new UnsupportedOperationException("Not supported."); } + @Nonnull public String predictJsCodegen(@Nonnull final String[] featureNames, @Nonnull final String[] classNames) { StringBuilder buf = new StringBuilder(1024); @@ -1076,6 +1132,16 @@ public String predictJsCodegen(@Nonnull final String[] featureNames, return buf.toString(); } + @Deprecated + @Nonnull + public String predictOpCodegen(@Nonnull final String sep) { + List opslist = new ArrayList(); + _root.opCodegen(opslist, 0); + opslist.add("call end"); + String scripts = StringUtils.concat(opslist, sep); + return scripts; + } + @Nonnull public byte[] serialize(boolean compress) throws HiveException { try { diff --git 
a/core/src/main/java/hivemall/smile/regression/RegressionTree.java b/core/src/main/java/hivemall/smile/regression/RegressionTree.java index 38b7b8347..06708763b 100755 --- a/core/src/main/java/hivemall/smile/regression/RegressionTree.java +++ b/core/src/main/java/hivemall/smile/regression/RegressionTree.java @@ -49,6 +49,7 @@ import hivemall.utils.collections.sets.IntArraySet; import hivemall.utils.collections.sets.IntSet; import hivemall.utils.lang.ObjectUtils; +import hivemall.utils.lang.StringUtils; import hivemall.utils.lang.mutable.MutableInt; import hivemall.utils.math.MathUtils; @@ -56,7 +57,9 @@ import java.io.IOException; import java.io.ObjectInput; import java.io.ObjectOutput; +import java.util.ArrayList; import java.util.Arrays; +import java.util.List; import java.util.PriorityQueue; import javax.annotation.Nonnull; @@ -373,6 +376,57 @@ public void exportGraphviz(@Nonnull final StringBuilder builder, } } + @Deprecated + public int opCodegen(@Nonnull final List scripts, int depth) { + int selfDepth = 0; + final StringBuilder buf = new StringBuilder(); + if (trueChild == null && falseChild == null) { + buf.append("push ").append(output); + scripts.add(buf.toString()); + buf.setLength(0); + buf.append("goto last"); + scripts.add(buf.toString()); + selfDepth += 2; + } else { + if (splitFeatureType == AttributeType.NOMINAL) { + buf.append("push ").append("x[").append(splitFeature).append("]"); + scripts.add(buf.toString()); + buf.setLength(0); + buf.append("push ").append(splitValue); + scripts.add(buf.toString()); + buf.setLength(0); + buf.append("ifeq "); + scripts.add(buf.toString()); + depth += 3; + selfDepth += 3; + int trueDepth = trueChild.opCodegen(scripts, depth); + selfDepth += trueDepth; + scripts.set(depth - 1, "ifeq " + String.valueOf(depth + trueDepth)); + int falseDepth = falseChild.opCodegen(scripts, depth + trueDepth); + selfDepth += falseDepth; + } else if (splitFeatureType == AttributeType.NUMERIC) { + buf.append("push ").append("x[").append(splitFeature).append("]"); + scripts.add(buf.toString()); + buf.setLength(0); + buf.append("push ").append(splitValue); + scripts.add(buf.toString()); + buf.setLength(0); + buf.append("ifle "); + scripts.add(buf.toString()); + depth += 3; + selfDepth += 3; + int trueDepth = trueChild.opCodegen(scripts, depth); + selfDepth += trueDepth; + scripts.set(depth - 1, "ifle " + String.valueOf(depth + trueDepth)); + int falseDepth = falseChild.opCodegen(scripts, depth + trueDepth); + selfDepth += falseDepth; + } else { + throw new IllegalStateException("Unsupported attribute type: " + + splitFeatureType); + } + } + return selfDepth; + } @Override public void writeExternal(ObjectOutput out) throws IOException { @@ -932,12 +986,23 @@ public double predict(@Nonnull final Vector x) { return _root.predict(x); } + @Nonnull public String predictJsCodegen(@Nonnull final String[] featureNames) { StringBuilder buf = new StringBuilder(1024); _root.exportJavascript(buf, featureNames, 0); return buf.toString(); } + @Deprecated + @Nonnull + public String predictOpCodegen(@Nonnull String sep) { + List opslist = new ArrayList(); + _root.opCodegen(opslist, 0); + opslist.add("call end"); + String scripts = StringUtils.concat(opslist, sep); + return scripts; + } + @Nonnull public byte[] serialize(boolean compress) throws HiveException { try { diff --git a/core/src/main/java/hivemall/smile/tools/RandomForestEnsembleUDAF.java b/core/src/main/java/hivemall/smile/tools/RandomForestEnsembleUDAF.java index 40957cbde..7f1d1ecdf 100644 --- 
a/core/src/main/java/hivemall/smile/tools/RandomForestEnsembleUDAF.java +++ b/core/src/main/java/hivemall/smile/tools/RandomForestEnsembleUDAF.java @@ -18,14 +18,19 @@ */ package hivemall.smile.tools; +import hivemall.utils.collections.lists.IntArrayList; import hivemall.utils.hadoop.HiveUtils; import hivemall.utils.hadoop.WritableUtils; +import hivemall.utils.lang.Counter; import hivemall.utils.lang.Preconditions; import hivemall.utils.lang.SizeOf; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; +import java.util.Map; + +import javax.annotation.CheckForNull; import javax.annotation.Nonnull; import javax.annotation.Nullable; @@ -44,6 +49,7 @@ import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.StandardListObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.StandardMapObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.StructField; import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.DoubleObjectInspector; @@ -55,7 +61,7 @@ @Description( name = "rf_ensemble", - value = "_FUNC_(int yhat, array<double> proba [, double model_weight=1.0])" + value = "_FUNC_(int yhat [, array<double> proba [, double model_weight=1.0]])" + " - Returns ensembled prediction results in <int label, double probability, array<double> probabilities>") public final class RandomForestEnsembleUDAF extends AbstractGenericUDAFResolver { @@ -64,30 +70,225 @@ public RandomForestEnsembleUDAF() { } @Override - public GenericUDAFEvaluator getEvaluator(@Nonnull TypeInfo[] typeInfo) throws SemanticException { - if (typeInfo.length != 2 && typeInfo.length != 3) { - throw new UDFArgumentLengthException("Expected 2 or 3 arguments but got " - + typeInfo.length); + public GenericUDAFEvaluator getEvaluator(@Nonnull final TypeInfo[] typeInfo) + throws SemanticException { + switch (typeInfo.length) { + case 1: { + if (!HiveUtils.isIntegerTypeInfo(typeInfo[0])) { + throw new UDFArgumentTypeException(0, "Expected INT for yhat: " + typeInfo[0]); + } + return new RfEvaluatorV1(); + } + case 3: + if (!HiveUtils.isFloatingPointTypeInfo(typeInfo[2])) { + throw new UDFArgumentTypeException(2, + "Expected DOUBLE or FLOAT for model_weight: " + typeInfo[2]); + } + /* fall through */ + case 2: {// typeInfo.length == 2 || typeInfo.length == 3 + if (!HiveUtils.isIntegerTypeInfo(typeInfo[0])) { + throw new UDFArgumentTypeException(0, "Expected INT for yhat: " + typeInfo[0]); + } + if (!HiveUtils.isFloatingPointListTypeInfo(typeInfo[1])) { + throw new UDFArgumentTypeException(1, + "ARRAY<double> is expected for posteriori: " + typeInfo[1]); + } + return new RfEvaluatorV2(); + } + default: + throw new UDFArgumentLengthException("Expected 1~3 arguments but got " + + typeInfo.length); + } + } + + @Deprecated + public static final class RfEvaluatorV1 extends GenericUDAFEvaluator { + + // original input + private PrimitiveObjectInspector yhatOI; + + // partial aggregation + private StandardMapObjectInspector internalMergeOI; + private IntObjectInspector keyOI; + private IntObjectInspector valueOI; + + public RfEvaluatorV1() { + super(); + } + + @Override + public ObjectInspector init(@Nonnull Mode mode, @Nonnull ObjectInspector[] argOIs) + throws HiveException { + super.init(mode, argOIs); + + // initialize input + if (mode == Mode.PARTIAL1 || mode == Mode.COMPLETE) {// from original data + this.yhatOI =
HiveUtils.asIntegerOI(argOIs[0]); + } else {// from partial aggregation + this.internalMergeOI = (StandardMapObjectInspector) argOIs[0]; + this.keyOI = HiveUtils.asIntOI(internalMergeOI.getMapKeyObjectInspector()); + this.valueOI = HiveUtils.asIntOI(internalMergeOI.getMapValueObjectInspector()); + } + + // initialize output + final ObjectInspector outputOI; + if (mode == Mode.PARTIAL1 || mode == Mode.PARTIAL2) {// terminatePartial + outputOI = ObjectInspectorFactory.getStandardMapObjectInspector( + PrimitiveObjectInspectorFactory.javaIntObjectInspector, + PrimitiveObjectInspectorFactory.javaIntObjectInspector); + } else {// terminate + List fieldNames = new ArrayList<>(3); + List fieldOIs = new ArrayList<>(3); + fieldNames.add("label"); + fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector); + fieldNames.add("probability"); + fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector); + fieldNames.add("probabilities"); + fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector)); + outputOI = ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, + fieldOIs); + } + return outputOI; } - if (!HiveUtils.isIntegerTypeInfo(typeInfo[0])) { - throw new UDFArgumentTypeException(0, "Expected INT for yhat: " + typeInfo[0]); + + @Override + public RfAggregationBufferV1 getNewAggregationBuffer() throws HiveException { + RfAggregationBufferV1 buf = new RfAggregationBufferV1(); + buf.reset(); + return buf; } - if (!HiveUtils.isFloatingPointListTypeInfo(typeInfo[1])) { - throw new UDFArgumentTypeException(1, "ARRAY is expected for posteriori: " - + typeInfo[1]); + + @Override + public void reset(AggregationBuffer agg) throws HiveException { + RfAggregationBufferV1 buf = (RfAggregationBufferV1) agg; + buf.reset(); + } + + @Override + public void iterate(AggregationBuffer agg, Object[] parameters) throws HiveException { + RfAggregationBufferV1 buf = (RfAggregationBufferV1) agg; + + Preconditions.checkNotNull(parameters[0]); + int yhat = PrimitiveObjectInspectorUtils.getInt(parameters[0], yhatOI); + + buf.iterate(yhat); } - if (typeInfo.length == 3) { - if (!HiveUtils.isFloatingPointTypeInfo(typeInfo[2])) { - throw new UDFArgumentTypeException(2, "Expected DOUBLE or FLOAT for model_weight: " - + typeInfo[2]); + + @Override + public Object terminatePartial(AggregationBuffer agg) throws HiveException { + RfAggregationBufferV1 buf = (RfAggregationBufferV1) agg; + + return buf.terminatePartial(); + } + + @Override + public void merge(AggregationBuffer agg, Object partial) throws HiveException { + final RfAggregationBufferV1 buf = (RfAggregationBufferV1) agg; + + Map partialResult = internalMergeOI.getMap(partial); + for (Map.Entry entry : partialResult.entrySet()) { + putIntoMap(entry.getKey(), entry.getValue(), buf); } } - return new RfEvaluator(); + + private void putIntoMap(@CheckForNull Object key, @CheckForNull Object value, + @Nonnull RfAggregationBufferV1 dst) { + Preconditions.checkNotNull(key); + Preconditions.checkNotNull(value); + + int k = keyOI.get(key); + int v = valueOI.get(value); + dst.merge(k, v); + } + + @Override + public Object terminate(AggregationBuffer agg) throws HiveException { + RfAggregationBufferV1 buf = (RfAggregationBufferV1) agg; + + return buf.terminate(); + } + + } + + public static final class RfAggregationBufferV1 extends AbstractAggregationBuffer { + + @Nonnull + private Counter partial; + + public RfAggregationBufferV1() { + super(); + reset(); + } + + 
void reset() { + this.partial = new Counter(); + } + + void iterate(final int k) { + partial.increment(k); + } + + @Nonnull + Map terminatePartial() { + return partial.getMap(); + } + + void merge(final int k, final int v) { + partial.increment(Integer.valueOf(k), v); + } + + @Nullable + Object[] terminate() { + final Map counts = partial.getMap(); + + final int size = counts.size(); + if (size == 0) { + return null; + } + + final IntArrayList keyList = new IntArrayList(size); + long totalCnt = 0L; + Integer maxKey = null; + int maxCnt = Integer.MIN_VALUE; + for (Map.Entry e : counts.entrySet()) { + Integer key = e.getKey(); + keyList.add(key); + int cnt = e.getValue().intValue(); + totalCnt += cnt; + if (cnt >= maxCnt) { + maxCnt = cnt; + maxKey = key; + } + } + + final int[] keyArray = keyList.toArray(); + Arrays.sort(keyArray); + int last = keyArray[keyArray.length - 1]; + + double totalCnt_d = (double) totalCnt; + final double[] probabilities = new double[Math.max(2, last + 1)]; + for (int i = 0, len = probabilities.length; i < len; i++) { + final Integer cnt = counts.get(Integer.valueOf(i)); + if (cnt == null) { + probabilities[i] = 0.d; + } else { + probabilities[i] = cnt.intValue() / totalCnt_d; + } + } + + Object[] result = new Object[3]; + result[0] = new IntWritable(maxKey); + double proba = maxCnt / totalCnt_d; + result[1] = new DoubleWritable(proba); + result[2] = WritableUtils.toWritableList(probabilities); + return result; + } + } @SuppressWarnings("deprecation") - public static final class RfEvaluator extends GenericUDAFEvaluator { + public static final class RfEvaluatorV2 extends GenericUDAFEvaluator { private PrimitiveObjectInspector yhatOI; private ListObjectInspector posterioriOI; @@ -100,7 +301,7 @@ public static final class RfEvaluator extends GenericUDAFEvaluator { private IntObjectInspector sizeFieldOI; private StandardListObjectInspector posterioriFieldOI; - public RfEvaluator() { + public RfEvaluatorV2() { super(); } @@ -152,21 +353,21 @@ public ObjectInspector init(@Nonnull Mode mode, @Nonnull ObjectInspector[] param } @Override - public RfAggregationBuffer getNewAggregationBuffer() throws HiveException { - RfAggregationBuffer buf = new RfAggregationBuffer(); + public RfAggregationBufferV2 getNewAggregationBuffer() throws HiveException { + RfAggregationBufferV2 buf = new RfAggregationBufferV2(); reset(buf); return buf; } @Override public void reset(AggregationBuffer agg) throws HiveException { - RfAggregationBuffer buf = (RfAggregationBuffer) agg; + RfAggregationBufferV2 buf = (RfAggregationBufferV2) agg; buf.reset(); } @Override public void iterate(AggregationBuffer agg, Object[] parameters) throws HiveException { - RfAggregationBuffer buf = (RfAggregationBuffer) agg; + RfAggregationBufferV2 buf = (RfAggregationBufferV2) agg; Preconditions.checkNotNull(parameters[0]); int yhat = PrimitiveObjectInspectorUtils.getInt(parameters[0], yhatOI); @@ -185,7 +386,7 @@ public void iterate(AggregationBuffer agg, Object[] parameters) throws HiveExcep @Override public Object terminatePartial(AggregationBuffer agg) throws HiveException { - RfAggregationBuffer buf = (RfAggregationBuffer) agg; + RfAggregationBufferV2 buf = (RfAggregationBufferV2) agg; if (buf._k == -1) { return null; } @@ -201,7 +402,7 @@ public void merge(AggregationBuffer agg, Object partial) throws HiveException { if (partial == null) { return; } - RfAggregationBuffer buf = (RfAggregationBuffer) agg; + RfAggregationBufferV2 buf = (RfAggregationBufferV2) agg; Object o1 = 
internalMergeOI.getStructFieldData(partial, sizeField); int size = sizeFieldOI.get(o1); @@ -220,7 +421,7 @@ public void merge(AggregationBuffer agg, Object partial) throws HiveException { @Override public Object terminate(AggregationBuffer agg) throws HiveException { - RfAggregationBuffer buf = (RfAggregationBuffer) agg; + RfAggregationBufferV2 buf = (RfAggregationBufferV2) agg; if (buf._k == -1) { return null; } @@ -239,13 +440,13 @@ public Object terminate(AggregationBuffer agg) throws HiveException { } - public static final class RfAggregationBuffer extends AbstractAggregationBuffer { + public static final class RfAggregationBufferV2 extends AbstractAggregationBuffer { @Nullable private double[] _posteriori; private int _k; - public RfAggregationBuffer() { + public RfAggregationBufferV2() { super(); reset(); } diff --git a/core/src/main/java/hivemall/smile/tools/TreePredictUDFv1.java b/core/src/main/java/hivemall/smile/tools/TreePredictUDFv1.java new file mode 100644 index 000000000..4053856de --- /dev/null +++ b/core/src/main/java/hivemall/smile/tools/TreePredictUDFv1.java @@ -0,0 +1,473 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package hivemall.smile.tools; + +import hivemall.annotations.Since; +import hivemall.smile.classification.DecisionTree; +import hivemall.smile.regression.RegressionTree; +import hivemall.smile.vm.StackMachine; +import hivemall.smile.vm.VMRuntimeException; +import hivemall.utils.codec.Base91; +import hivemall.utils.codec.DeflateCodec; +import hivemall.utils.hadoop.HiveUtils; +import hivemall.utils.io.IOUtils; + +import java.io.Closeable; +import java.io.IOException; +import java.util.Arrays; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; +import javax.script.Bindings; +import javax.script.Compilable; +import javax.script.CompiledScript; +import javax.script.ScriptEngine; +import javax.script.ScriptEngineManager; +import javax.script.ScriptException; + +import org.apache.hadoop.hive.ql.exec.Description; +import org.apache.hadoop.hive.ql.exec.MapredContext; +import org.apache.hadoop.hive.ql.exec.UDFArgumentException; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.ql.udf.UDFType; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDF; +import org.apache.hadoop.hive.serde2.io.DoubleWritable; +import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.StringObjectInspector; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.Text; +import org.apache.hadoop.io.Writable; +import org.apache.hadoop.mapred.JobConf; + +@Description( + name = "tree_predict_v1", + value = "_FUNC_(string modelId, int modelType, string script, array features [, const boolean classification])" + + " - Returns a prediction result of a random forest") +@UDFType(deterministic = true, stateful = false) +@Since(version = "v0.5-rc.1") +@Deprecated +public final class TreePredictUDFv1 extends GenericUDF { + + private boolean classification; + private PrimitiveObjectInspector modelTypeOI; + private StringObjectInspector stringOI; + private ListObjectInspector featureListOI; + private PrimitiveObjectInspector featureElemOI; + + @Nullable + private transient Evaluator evaluator; + private boolean support_javascript_eval = true; + + @Override + public void configure(MapredContext context) { + super.configure(context); + + if (context != null) { + JobConf conf = context.getJobConf(); + String tdJarVersion = conf.get("td.jar.version"); + if (tdJarVersion != null) { + this.support_javascript_eval = false; + } + } + } + + @Override + public ObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException { + if (argOIs.length != 4 && argOIs.length != 5) { + throw new UDFArgumentException("_FUNC_ takes 4 or 5 arguments"); + } + + this.modelTypeOI = HiveUtils.asIntegerOI(argOIs[1]); + this.stringOI = HiveUtils.asStringOI(argOIs[2]); + ListObjectInspector listOI = HiveUtils.asListOI(argOIs[3]); + this.featureListOI = listOI; + ObjectInspector elemOI = listOI.getListElementObjectInspector(); + this.featureElemOI = HiveUtils.asDoubleCompatibleOI(elemOI); + + boolean classification = false; + if (argOIs.length == 5) { + classification = HiveUtils.getConstBoolean(argOIs[4]); + } + this.classification = classification; + + 
if (classification) { + return PrimitiveObjectInspectorFactory.writableIntObjectInspector; + } else { + return PrimitiveObjectInspectorFactory.writableDoubleObjectInspector; + } + } + + @Override + public Writable evaluate(@Nonnull DeferredObject[] arguments) throws HiveException { + Object arg0 = arguments[0].get(); + if (arg0 == null) { + throw new HiveException("ModelId was null"); + } + // Not using string OI for backward compatibilities + String modelId = arg0.toString(); + + Object arg1 = arguments[1].get(); + int modelTypeId = PrimitiveObjectInspectorUtils.getInt(arg1, modelTypeOI); + ModelType modelType = ModelType.resolve(modelTypeId); + + Object arg2 = arguments[2].get(); + if (arg2 == null) { + return null; + } + Text script = stringOI.getPrimitiveWritableObject(arg2); + + Object arg3 = arguments[3].get(); + if (arg3 == null) { + throw new HiveException("array features was null"); + } + double[] features = HiveUtils.asDoubleArray(arg3, featureListOI, featureElemOI); + + if (evaluator == null) { + this.evaluator = getEvaluator(modelType, support_javascript_eval); + } + + Writable result = evaluator.evaluate(modelId, modelType.isCompressed(), script, features, + classification); + return result; + } + + @Nonnull + private static Evaluator getEvaluator(@Nonnull ModelType type, boolean supportJavascriptEval) + throws UDFArgumentException { + final Evaluator evaluator; + switch (type) { + case serialization: + case serialization_compressed: { + evaluator = new JavaSerializationEvaluator(); + break; + } + case opscode: + case opscode_compressed: { + evaluator = new StackmachineEvaluator(); + break; + } + case javascript: + case javascript_compressed: { + if (!supportJavascriptEval) { + throw new UDFArgumentException( + "Javascript evaluation is not allowed in Treasure Data env"); + } + evaluator = new JavascriptEvaluator(); + break; + } + default: + throw new UDFArgumentException("Unexpected model type was detected: " + type); + } + return evaluator; + } + + @Override + public void close() throws IOException { + this.modelTypeOI = null; + this.stringOI = null; + this.featureElemOI = null; + this.featureListOI = null; + IOUtils.closeQuietly(evaluator); + this.evaluator = null; + } + + @Override + public String getDisplayString(String[] children) { + return "tree_predict(" + Arrays.toString(children) + ")"; + } + + enum ModelType { + + // not compressed + opscode(1, false), javascript(2, false), serialization(3, false), + // compressed + opscode_compressed(-1, true), javascript_compressed(-2, true), + serialization_compressed(-3, true); + + private final int id; + private final boolean compressed; + + private ModelType(int id, boolean compressed) { + this.id = id; + this.compressed = compressed; + } + + int getId() { + return id; + } + + boolean isCompressed() { + return compressed; + } + + @Nonnull + static ModelType resolve(final int id) { + final ModelType type; + switch (id) { + case 1: + type = opscode; + break; + case -1: + type = opscode_compressed; + break; + case 2: + type = javascript; + break; + case -2: + type = javascript_compressed; + break; + case 3: + type = serialization; + break; + case -3: + type = serialization_compressed; + break; + default: + throw new IllegalStateException("Unexpected ID for ModelType: " + id); + } + return type; + } + + } + + public interface Evaluator extends Closeable { + + @Nullable + Writable evaluate(@Nonnull String modelId, boolean compressed, @Nonnull final Text script, + @Nonnull final double[] features, final boolean classification) + 
throws HiveException; + + } + + static final class JavaSerializationEvaluator implements Evaluator { + + @Nullable + private String prevModelId = null; + private DecisionTree.Node cNode = null; + private RegressionTree.Node rNode = null; + + JavaSerializationEvaluator() {} + + @Override + public Writable evaluate(@Nonnull String modelId, boolean compressed, @Nonnull Text script, + double[] features, boolean classification) throws HiveException { + if (classification) { + return evaluateClassification(modelId, compressed, script, features); + } else { + return evaluteRegression(modelId, compressed, script, features); + } + } + + private IntWritable evaluateClassification(@Nonnull String modelId, boolean compressed, + @Nonnull Text script, double[] features) throws HiveException { + if (!modelId.equals(prevModelId)) { + this.prevModelId = modelId; + int length = script.getLength(); + byte[] b = script.getBytes(); + b = Base91.decode(b, 0, length); + this.cNode = DecisionTree.deserialize(b, b.length, compressed); + } + assert (cNode != null); + int result = cNode.predict(features); + return new IntWritable(result); + } + + private DoubleWritable evaluteRegression(@Nonnull String modelId, boolean compressed, + @Nonnull Text script, double[] features) throws HiveException { + if (!modelId.equals(prevModelId)) { + this.prevModelId = modelId; + int length = script.getLength(); + byte[] b = script.getBytes(); + b = Base91.decode(b, 0, length); + this.rNode = RegressionTree.deserialize(b, b.length, compressed); + } + assert (rNode != null); + double result = rNode.predict(features); + return new DoubleWritable(result); + } + + @Override + public void close() throws IOException {} + + } + + static final class StackmachineEvaluator implements Evaluator { + + private String prevModelId = null; + private StackMachine prevVM = null; + private DeflateCodec codec = null; + + StackmachineEvaluator() {} + + @Override + public Writable evaluate(@Nonnull String modelId, boolean compressed, @Nonnull Text script, + double[] features, boolean classification) throws HiveException { + final String scriptStr; + if (compressed) { + if (codec == null) { + this.codec = new DeflateCodec(false, true); + } + byte[] b = script.getBytes(); + int len = script.getLength(); + b = Base91.decode(b, 0, len); + try { + b = codec.decompress(b); + } catch (IOException e) { + throw new HiveException("decompression failed", e); + } + scriptStr = new String(b); + } else { + scriptStr = script.toString(); + } + + final StackMachine vm; + if (modelId.equals(prevModelId)) { + vm = prevVM; + } else { + vm = new StackMachine(); + try { + vm.compile(scriptStr); + } catch (VMRuntimeException e) { + throw new HiveException("failed to compile StackMachine", e); + } + this.prevModelId = modelId; + this.prevVM = vm; + } + + try { + vm.eval(features); + } catch (VMRuntimeException vme) { + throw new HiveException("failed to eval StackMachine", vme); + } catch (Throwable e) { + throw new HiveException("failed to eval StackMachine", e); + } + + Double result = vm.getResult(); + if (result == null) { + return null; + } + if (classification) { + return new IntWritable(result.intValue()); + } else { + return new DoubleWritable(result.doubleValue()); + } + } + + @Override + public void close() throws IOException { + IOUtils.closeQuietly(codec); + } + + } + + static final class JavascriptEvaluator implements Evaluator { + + private final ScriptEngine scriptEngine; + private final Compilable compilableEngine; + + private String prevModelId = null; + 
private CompiledScript prevCompiled; + + private DeflateCodec codec = null; + + JavascriptEvaluator() throws UDFArgumentException { + ScriptEngineManager manager = new ScriptEngineManager(); + ScriptEngine engine = manager.getEngineByExtension("js"); + if (!(engine instanceof Compilable)) { + throw new UDFArgumentException("ScriptEngine was not compilable: " + + engine.getFactory().getEngineName() + " version " + + engine.getFactory().getEngineVersion()); + } + this.scriptEngine = engine; + this.compilableEngine = (Compilable) engine; + } + + @Override + public Writable evaluate(@Nonnull String modelId, boolean compressed, @Nonnull Text script, + double[] features, boolean classification) throws HiveException { + final String scriptStr; + if (compressed) { + if (codec == null) { + this.codec = new DeflateCodec(false, true); + } + byte[] b = script.getBytes(); + int len = script.getLength(); + b = Base91.decode(b, 0, len); + try { + b = codec.decompress(b); + } catch (IOException e) { + throw new HiveException("decompression failed", e); + } + scriptStr = new String(b); + } else { + scriptStr = script.toString(); + } + + final CompiledScript compiled; + if (modelId.equals(prevModelId)) { + compiled = prevCompiled; + } else { + try { + compiled = compilableEngine.compile(scriptStr); + } catch (ScriptException e) { + throw new HiveException("failed to compile: \n" + script, e); + } + this.prevCompiled = compiled; + } + + final Bindings bindings = scriptEngine.createBindings(); + final Object result; + try { + bindings.put("x", features); + result = compiled.eval(bindings); + } catch (ScriptException se) { + throw new HiveException("failed to evaluate: \n" + script, se); + } catch (Throwable e) { + throw new HiveException("failed to evaluate: \n" + script, e); + } finally { + bindings.clear(); + } + + if (result == null) { + return null; + } + if (!(result instanceof Number)) { + throw new HiveException("Got an unexpected non-number result: " + result); + } + if (classification) { + Number casted = (Number) result; + return new IntWritable(casted.intValue()); + } else { + Number casted = (Number) result; + return new DoubleWritable(casted.doubleValue()); + } + } + + @Override + public void close() throws IOException { + IOUtils.closeQuietly(codec); + } + + } + +} diff --git a/core/src/main/java/hivemall/smile/vm/Operation.java b/core/src/main/java/hivemall/smile/vm/Operation.java new file mode 100644 index 000000000..ba122a524 --- /dev/null +++ b/core/src/main/java/hivemall/smile/vm/Operation.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package hivemall.smile.vm; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; + +// for tree_predict_v1 +@Deprecated +public final class Operation { + + final OperationEnum op; + final String operand; + + public Operation(@Nonnull OperationEnum op) { + this(op, null); + } + + public Operation(@Nonnull OperationEnum op, @Nullable String operand) { + this.op = op; + this.operand = operand; + } + + public enum OperationEnum { + ADD, SUB, DIV, MUL, DUP, // reserved + PUSH, POP, GOTO, IFEQ, IFEQ2, IFGE, IFGT, IFLE, IFLT, CALL; // used + + static OperationEnum valueOfLowerCase(String op) { + return OperationEnum.valueOf(op.toUpperCase()); + } + } + + @Override + public String toString() { + return op.toString() + (operand != null ? (" " + operand) : ""); + } + +} diff --git a/core/src/main/java/hivemall/smile/vm/StackMachine.java b/core/src/main/java/hivemall/smile/vm/StackMachine.java new file mode 100644 index 000000000..b5168fd2a --- /dev/null +++ b/core/src/main/java/hivemall/smile/vm/StackMachine.java @@ -0,0 +1,302 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package hivemall.smile.vm; + +import hivemall.utils.lang.StringUtils; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Stack; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; + +// for tree_predict_v1 +@Deprecated +public final class StackMachine { + public static final String SEP = "; "; + + @Nonnull + private final List code; + @Nonnull + private final Map valuesMap; + @Nonnull + private final Map jumpMap; + @Nonnull + private final Stack programStack; + + /** + * Instruction pointer + */ + private int IP; + + /** + * Stack pointer + */ + @SuppressWarnings("unused") + private int SP; + + private int codeLength; + private boolean[] done; + private Double result; + + public StackMachine() { + this.code = new ArrayList(); + this.valuesMap = new HashMap(); + this.jumpMap = new HashMap(); + this.programStack = new Stack(); + this.SP = 0; + this.result = null; + } + + public void run(@Nonnull String scripts, @Nonnull double[] features) throws VMRuntimeException { + compile(scripts); + eval(features); + } + + public void run(@Nonnull List opslist, @Nonnull double[] features) + throws VMRuntimeException { + compile(opslist); + eval(features); + } + + public void compile(@Nonnull String scripts) throws VMRuntimeException { + List opslist = Arrays.asList(scripts.split(SEP)); + compile(opslist); + } + + public void compile(@Nonnull List opslist) throws VMRuntimeException { + for (String line : opslist) { + String[] ops = line.split(" ", -1); + if (ops.length == 2) { + Operation.OperationEnum o = Operation.OperationEnum.valueOfLowerCase(ops[0]); + code.add(new Operation(o, ops[1])); + } else { + Operation.OperationEnum o = Operation.OperationEnum.valueOfLowerCase(ops[0]); + code.add(new Operation(o)); + } + } + + int size = opslist.size(); + this.codeLength = size - 1; + this.done = new boolean[size]; + } + + public void eval(final double[] features) throws VMRuntimeException { + init(); + bind(features); + execute(0); + } + + private void init() { + valuesMap.clear(); + jumpMap.clear(); + programStack.clear(); + this.SP = 0; + this.result = null; + Arrays.fill(done, false); + } + + private void bind(final double[] features) { + final StringBuilder buf = new StringBuilder(); + for (int i = 0; i < features.length; i++) { + String bindKey = buf.append("x[").append(i).append("]").toString(); + valuesMap.put(bindKey, features[i]); + StringUtils.clear(buf); + } + } + + private void execute(int entryPoint) throws VMRuntimeException { + valuesMap.put("end", -1.0); + jumpMap.put("last", codeLength); + + IP = entryPoint; + + while (IP < code.size()) { + if (done[IP]) { + throw new VMRuntimeException("There is a infinite loop in the Machine code."); + } + done[IP] = true; + Operation currentOperation = code.get(IP); + if (!executeOperation(currentOperation)) { + return; + } + } + } + + @Nullable + public Double getResult() { + return result; + } + + private Double pop() { + SP--; + return programStack.pop(); + } + + private Double push(Double val) { + programStack.push(val); + SP++; + return val; + } + + private boolean executeOperation(Operation currentOperation) throws VMRuntimeException { + if (IP < 0) { + return false; + } + switch (currentOperation.op) { + case GOTO: { + if (StringUtils.isInt(currentOperation.operand)) { + IP = Integer.parseInt(currentOperation.operand); + } else { + IP = jumpMap.get(currentOperation.operand); + } + break; + } + case CALL: { + double candidateIP 
= valuesMap.get(currentOperation.operand); + if (candidateIP < 0) { + evaluateBuiltinByName(currentOperation.operand); + IP++; + } + break; + } + case IFEQ: { + double a = pop(); + double b = pop(); + if (a == b) { + IP++; + } else { + if (StringUtils.isInt(currentOperation.operand)) { + IP = Integer.parseInt(currentOperation.operand); + } else { + IP = jumpMap.get(currentOperation.operand); + } + } + break; + } + case IFEQ2: {// follow the rule of smile's Math class. + double a = pop(); + double b = pop(); + if (smile.math.Math.equals(a, b)) { + IP++; + } else { + if (StringUtils.isInt(currentOperation.operand)) { + IP = Integer.parseInt(currentOperation.operand); + } else { + IP = jumpMap.get(currentOperation.operand); + } + } + break; + } + case IFGE: { + double lower = pop(); + double upper = pop(); + if (upper >= lower) { + IP++; + } else { + if (StringUtils.isInt(currentOperation.operand)) { + IP = Integer.parseInt(currentOperation.operand); + } else { + IP = jumpMap.get(currentOperation.operand); + } + } + break; + } + case IFGT: { + double lower = pop(); + double upper = pop(); + if (upper > lower) { + IP++; + } else { + if (StringUtils.isInt(currentOperation.operand)) { + IP = Integer.parseInt(currentOperation.operand); + } else { + IP = jumpMap.get(currentOperation.operand); + } + } + break; + } + case IFLE: { + double lower = pop(); + double upper = pop(); + if (upper <= lower) { + IP++; + } else { + if (StringUtils.isInt(currentOperation.operand)) { + IP = Integer.parseInt(currentOperation.operand); + } else { + IP = jumpMap.get(currentOperation.operand); + } + } + break; + } + case IFLT: { + double lower = pop(); + double upper = pop(); + if (upper < lower) { + IP++; + } else { + if (StringUtils.isInt(currentOperation.operand)) { + IP = Integer.parseInt(currentOperation.operand); + } else { + IP = jumpMap.get(currentOperation.operand); + } + } + break; + } + case POP: { + valuesMap.put(currentOperation.operand, pop()); + IP++; + break; + } + case PUSH: { + if (StringUtils.isDouble(currentOperation.operand)) { + push(Double.parseDouble(currentOperation.operand)); + } else { + Double v = valuesMap.get(currentOperation.operand); + if (v == null) { + throw new VMRuntimeException("value is not binded: " + + currentOperation.operand); + } + push(v); + } + IP++; + break; + } + default: + throw new VMRuntimeException("Machine code has wrong opcode :" + + currentOperation.op); + } + return true; + + } + + private void evaluateBuiltinByName(String name) throws VMRuntimeException { + if (name.equals("end")) { + this.result = pop(); + } else { + throw new VMRuntimeException("Machine code has wrong builin function :" + name); + } + } + +} diff --git a/core/src/main/java/hivemall/smile/vm/VMRuntimeException.java b/core/src/main/java/hivemall/smile/vm/VMRuntimeException.java new file mode 100644 index 000000000..360dbecd2 --- /dev/null +++ b/core/src/main/java/hivemall/smile/vm/VMRuntimeException.java @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.smile.vm; + +// for tree_predict_v1 +@Deprecated +public class VMRuntimeException extends Exception { + private static final long serialVersionUID = -7378149197872357802L; + + public VMRuntimeException(String message) { + super(message); + } + + public VMRuntimeException(String message, Throwable cause) { + super(message, cause); + } + +} diff --git a/core/src/test/java/hivemall/smile/tools/TreePredictUDFv1Test.java b/core/src/test/java/hivemall/smile/tools/TreePredictUDFv1Test.java new file mode 100644 index 000000000..bf2ac1184 --- /dev/null +++ b/core/src/test/java/hivemall/smile/tools/TreePredictUDFv1Test.java @@ -0,0 +1,217 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.smile.tools; + +import static org.junit.Assert.assertEquals; +import hivemall.math.matrix.dense.RowMajorDenseMatrix2d; +import hivemall.smile.classification.DecisionTree; +import hivemall.smile.data.Attribute; +import hivemall.smile.regression.RegressionTree; +import hivemall.smile.tools.TreePredictUDFv1.ModelType; +import hivemall.smile.utils.SmileExtUtils; +import hivemall.smile.vm.StackMachine; +import hivemall.utils.lang.ArrayUtils; + +import java.io.BufferedInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.net.URL; +import java.text.ParseException; + +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredJavaObject; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredObject; +import org.apache.hadoop.hive.serde2.io.DoubleWritable; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; +import org.apache.hadoop.io.IntWritable; +import org.junit.Test; + +import smile.data.AttributeDataset; +import smile.data.parser.ArffParser; +import smile.math.Math; +import smile.validation.CrossValidation; +import smile.validation.LOOCV; +import smile.validation.RMSE; + +public class TreePredictUDFv1Test { + private static final boolean DEBUG = false; + + /** + * Test of learn method, of class DecisionTree. 
+ */ + @Test + public void testIris() throws IOException, ParseException, HiveException { + URL url = new URL( + "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff"); + InputStream is = new BufferedInputStream(url.openStream()); + + ArffParser arffParser = new ArffParser(); + arffParser.setResponseIndex(4); + AttributeDataset iris = arffParser.parse(is); + double[][] x = iris.toArray(new double[iris.size()][]); + int[] y = iris.toArray(new int[iris.size()]); + + int n = x.length; + LOOCV loocv = new LOOCV(n); + for (int i = 0; i < n; i++) { + double[][] trainx = Math.slice(x, loocv.train[i]); + int[] trainy = Math.slice(y, loocv.train[i]); + + Attribute[] attrs = SmileExtUtils.convertAttributeTypes(iris.attributes()); + DecisionTree tree = new DecisionTree(attrs, new RowMajorDenseMatrix2d(trainx, + x[0].length), trainy, 4); + assertEquals(tree.predict(x[loocv.test[i]]), evalPredict(tree, x[loocv.test[i]])); + } + } + + @Test + public void testCpu() throws IOException, ParseException, HiveException { + URL url = new URL( + "https://gist.githubusercontent.com/myui/ef17aabecf0c0c5bcb69/raw/aac0575b4d43072c6f3c82d9072fdefb61892694/cpu.arff"); + InputStream is = new BufferedInputStream(url.openStream()); + + ArffParser arffParser = new ArffParser(); + arffParser.setResponseIndex(6); + AttributeDataset data = arffParser.parse(is); + double[] datay = data.toArray(new double[data.size()]); + double[][] datax = data.toArray(new double[data.size()][]); + + int n = datax.length; + int k = 10; + + CrossValidation cv = new CrossValidation(n, k); + for (int i = 0; i < k; i++) { + double[][] trainx = Math.slice(datax, cv.train[i]); + double[] trainy = Math.slice(datay, cv.train[i]); + double[][] testx = Math.slice(datax, cv.test[i]); + + Attribute[] attrs = SmileExtUtils.convertAttributeTypes(data.attributes()); + RegressionTree tree = new RegressionTree(attrs, new RowMajorDenseMatrix2d(trainx, + trainx[0].length), trainy, 20); + + for (int j = 0; j < testx.length; j++) { + assertEquals(tree.predict(testx[j]), evalPredict(tree, testx[j]), 1.0); + } + } + } + + @Test + public void testCpu2() throws IOException, ParseException, HiveException { + URL url = new URL( + "https://gist.githubusercontent.com/myui/ef17aabecf0c0c5bcb69/raw/aac0575b4d43072c6f3c82d9072fdefb61892694/cpu.arff"); + InputStream is = new BufferedInputStream(url.openStream()); + + ArffParser arffParser = new ArffParser(); + arffParser.setResponseIndex(6); + AttributeDataset data = arffParser.parse(is); + double[] datay = data.toArray(new double[data.size()]); + double[][] datax = data.toArray(new double[data.size()][]); + + int n = datax.length; + int m = 3 * n / 4; + int[] index = Math.permutate(n); + + double[][] trainx = new double[m][]; + double[] trainy = new double[m]; + for (int i = 0; i < m; i++) { + trainx[i] = datax[index[i]]; + trainy[i] = datay[index[i]]; + } + + double[][] testx = new double[n - m][]; + double[] testy = new double[n - m]; + for (int i = m; i < n; i++) { + testx[i - m] = datax[index[i]]; + testy[i - m] = datay[index[i]]; + } + + Attribute[] attrs = SmileExtUtils.convertAttributeTypes(data.attributes()); + RegressionTree tree = new RegressionTree(attrs, new RowMajorDenseMatrix2d(trainx, + trainx[0].length), trainy, 20); + debugPrint(String.format("RMSE = %.4f\n", rmse(tree, testx, testy))); + + for (int i = m; i < n; i++) { + assertEquals(tree.predict(testx[i - m]), evalPredict(tree, testx[i - m]), 1.0); + } + } + + private static double 
rmse(RegressionTree regression, double[][] x, double[] y) { + final int n = x.length; + final double[] predictions = new double[n]; + for (int i = 0; i < n; i++) { + predictions[i] = regression.predict(x[i]); + } + return new RMSE().measure(y, predictions); + } + + private static int evalPredict(DecisionTree tree, double[] x) throws HiveException, IOException { + String opScript = tree.predictOpCodegen(StackMachine.SEP); + debugPrint(opScript); + + TreePredictUDFv1 udf = new TreePredictUDFv1(); + udf.initialize(new ObjectInspector[] { + PrimitiveObjectInspectorFactory.javaStringObjectInspector, + PrimitiveObjectInspectorFactory.javaIntObjectInspector, + PrimitiveObjectInspectorFactory.javaStringObjectInspector, + ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaDoubleObjectInspector), + ObjectInspectorUtils.getConstantObjectInspector( + PrimitiveObjectInspectorFactory.javaBooleanObjectInspector, true)}); + DeferredObject[] arguments = new DeferredObject[] {new DeferredJavaObject("model_id#1"), + new DeferredJavaObject(ModelType.opscode.getId()), + new DeferredJavaObject(opScript), new DeferredJavaObject(ArrayUtils.toList(x)), + new DeferredJavaObject(true)}; + + IntWritable result = (IntWritable) udf.evaluate(arguments); + udf.close(); + return result.get(); + } + + private static double evalPredict(RegressionTree tree, double[] x) throws HiveException, + IOException { + String opScript = tree.predictOpCodegen(StackMachine.SEP); + debugPrint(opScript); + + TreePredictUDFv1 udf = new TreePredictUDFv1(); + udf.initialize(new ObjectInspector[] { + PrimitiveObjectInspectorFactory.javaStringObjectInspector, + PrimitiveObjectInspectorFactory.javaIntObjectInspector, + PrimitiveObjectInspectorFactory.javaStringObjectInspector, + ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaDoubleObjectInspector), + ObjectInspectorUtils.getConstantObjectInspector( + PrimitiveObjectInspectorFactory.javaBooleanObjectInspector, false)}); + DeferredObject[] arguments = new DeferredObject[] {new DeferredJavaObject("model_id#1"), + new DeferredJavaObject(ModelType.opscode.getId()), + new DeferredJavaObject(opScript), new DeferredJavaObject(ArrayUtils.toList(x)), + new DeferredJavaObject(false)}; + + DoubleWritable result = (DoubleWritable) udf.evaluate(arguments); + udf.close(); + return result.get(); + } + + private static void debugPrint(String msg) { + if (DEBUG) { + System.out.println(msg); + } + } + +} diff --git a/docs/gitbook/binaryclass/titanic_rf.md b/docs/gitbook/binaryclass/titanic_rf.md index 64502b942..29784e06b 100644 --- a/docs/gitbook/binaryclass/titanic_rf.md +++ b/docs/gitbook/binaryclass/titanic_rf.md @@ -198,16 +198,16 @@ FROM ( FROM ( SELECT t.passengerid, - -- hivemall v0.4.1-alpha.2 or before - -- tree_predict(p.model, t.features, ${classification}) as predicted -- hivemall v0.4.1-alpha.3 or later -- tree_predict(p.model_id, p.model_type, p.pred_model, t.features, ${classification}) as predicted -- hivemall v0.5-rc.1 or later p.model_weight, tree_predict(p.model_id, p.model, t.features, ${classification}) as predicted + -- tree_predict_v1(p.model_id, p.model_type, p.pred_model, t.features, ${classification}) as predicted -- to use the old model in v0.5-rc.1 or later FROM ( SELECT - -- model_id, pred_model + -- hivemall v0.4.1-alpha.3 or later + -- model_id, model_type, pred_model -- hivemall v0.5-rc.1 or later model_id, model_weight, model FROM @@ -222,6 +222,9 @@ FROM ( ; ``` +> #### Caution +> `tree_predict_v1` 
is for the backward compatibility for using prediction models built before `v0.5-rc.1` on `v0.5-rc.1` or later. + # Kaggle submission ```sql @@ -338,15 +341,15 @@ FROM ( FROM ( SELECT t.passengerid, - -- hivemall v0.4.1-alpha.2 or before - -- tree_predict(p.model, t.features, ${classification}) as predicted -- hivemall v0.4.1-alpha.3 or later -- tree_predict(p.model_id, p.model_type, p.pred_model, t.features, ${classification}) as predicted -- hivemall v0.5-rc.1 or later p.model_weight, tree_predict(p.model_id, p.model, t.features, ${classification}) as predicted + -- tree_predict_v1(p.model_id, p.model_type, p.pred_model, t.features, ${classification}) as predicted -- to use the old model in v0.5-rc.1 or later FROM ( SELECT + -- hivemall v0.4.1-alpha.3 or later -- model_id, model_type, pred_model -- hivemall v0.5-rc.1 or later model_id, model_weight, model @@ -358,8 +361,7 @@ FROM ( ) t1 group by passengerid -) t2 -; +) t2; create or replace view rf_submit_03 as select diff --git a/docs/gitbook/multiclass/iris_randomforest.md b/docs/gitbook/multiclass/iris_randomforest.md index d0e8e8cd2..771c73343 100644 --- a/docs/gitbook/multiclass/iris_randomforest.md +++ b/docs/gitbook/multiclass/iris_randomforest.md @@ -17,6 +17,8 @@ under the License. --> + + # Dataset * https://archive.ics.uci.edu/ml/datasets/Iris @@ -219,13 +221,12 @@ SELECT FROM ( SELECT rowid, - -- hivemall v0.4.1-alpha.2 and before - -- tree_predict(p.model, t.features, ${classification}) as predicted -- hivemall v0.4.1 and later -- tree_predict(p.model_id, p.model_type, p.pred_model, t.features, ${classification}) as predicted -- hivemall v0.5-rc.1 or later p.model_weight, tree_predict(p.model_id, p.model, t.features, ${classification}) as predicted + -- tree_predict_v1(p.model_id, p.model_type, p.pred_model, t.features, ${classification}) as predicted -- to use the old model in v0.5-rc.1 or later FROM model p LEFT OUTER JOIN -- CROSS JOIN @@ -236,6 +237,9 @@ group by ; ``` +> #### Caution +> `tree_predict_v1` is for the backward compatibility for using prediction models built before `v0.5-rc.1` on `v0.5-rc.1` or later. + ### Parallelize Prediction The following query runs predictions in N-parallel. It would reduce elapsed time for prediction almost by N. 
@@ -257,15 +261,15 @@ SELECT FROM ( SELECT t.rowid, - -- hivemall v0.4.1-alpha.2 and before - -- tree_predict(p.pred_model, t.features, ${classification}) as predicted -- hivemall v0.4.1 and later -- tree_predict(p.model_id, p.model_type, p.pred_model, t.features, ${classification}) as predicted -- hivemall v0.5-rc.1 or later p.model_weight, tree_predict(p.model_id, p.model, t.features, ${classification}) as predicted + -- tree_predict_v1(p.model_id, p.model_type, p.pred_model, t.features, ${classification}) as predicted -- to use the old model in v0.5-rc.1 or later FROM ( SELECT + -- hivemall v0.4.1 and later -- model_id, model_type, pred_model -- hivemall v0.5-rc.1 or later model_id, model_weight, model @@ -275,8 +279,7 @@ FROM ( LEFT OUTER JOIN training t ) t1 group by - rowid -; + rowid; ``` # Evaluation diff --git a/resources/ddl/define-all-as-permanent.hive b/resources/ddl/define-all-as-permanent.hive index a3b672542..c59678a17 100644 --- a/resources/ddl/define-all-as-permanent.hive +++ b/resources/ddl/define-all-as-permanent.hive @@ -689,6 +689,10 @@ CREATE FUNCTION train_randomforest_regr as 'hivemall.smile.regression.RandomFore DROP FUNCTION IF EXISTS tree_predict; CREATE FUNCTION tree_predict as 'hivemall.smile.tools.TreePredictUDF' USING JAR '${hivemall_jar}'; +-- for backward compatibility +DROP FUNCTION IF EXISTS tree_predict_v1; +CREATE FUNCTION tree_predict_v1 as 'hivemall.smile.tools.TreePredictUDFv1' USING JAR '${hivemall_jar}'; + DROP FUNCTION IF EXISTS tree_export; CREATE FUNCTION tree_export as 'hivemall.smile.tools.TreeExportUDF' USING JAR '${hivemall_jar}'; diff --git a/resources/ddl/define-all.hive b/resources/ddl/define-all.hive index 77b6a982d..451453528 100644 --- a/resources/ddl/define-all.hive +++ b/resources/ddl/define-all.hive @@ -681,6 +681,10 @@ create temporary function train_randomforest_regr as 'hivemall.smile.regression.
drop temporary function if exists tree_predict; create temporary function tree_predict as 'hivemall.smile.tools.TreePredictUDF'; +-- for backward compatibility +drop temporary function if exists tree_predict_v1; +create temporary function tree_predict_v1 as 'hivemall.smile.tools.TreePredictUDFv1'; + drop temporary function if exists tree_export; create temporary function tree_export as 'hivemall.smile.tools.TreeExportUDF'; diff --git a/resources/ddl/define-all.spark b/resources/ddl/define-all.spark index 2193cd81a..2cf4d60b6 100644 --- a/resources/ddl/define-all.spark +++ b/resources/ddl/define-all.spark @@ -665,6 +665,9 @@ sqlContext.sql("CREATE TEMPORARY FUNCTION train_randomforest_regr AS 'hivemall.s sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS tree_predict") sqlContext.sql("CREATE TEMPORARY FUNCTION tree_predict AS 'hivemall.smile.tools.TreePredictUDF'") +sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS tree_predict_v1") +sqlContext.sql("CREATE TEMPORARY FUNCTION tree_predict_v1 AS 'hivemall.smile.tools.TreePredictUDFv1'") + sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS tree_export") sqlContext.sql("CREATE TEMPORARY FUNCTION tree_export AS 'hivemall.smile.tools.TreeExportUDF'") diff --git a/resources/ddl/define-udfs.td.hql b/resources/ddl/define-udfs.td.hql index 7742b0cd1..d1bdfa4eb 100644 --- a/resources/ddl/define-udfs.td.hql +++ b/resources/ddl/define-udfs.td.hql @@ -184,4 +184,4 @@ create temporary function concat_array as 'hivemall.tools.array.ArrayConcatUDF'; create temporary function pa2a_regress as 'hivemall.regression.PassiveAggressiveRegressionUDTF$PA2a'; create temporary function arow_regress as 'hivemall.regression.AROWRegressionUDTF'; create temporary function addBias as 'hivemall.ftvec.AddBiasUDF'; - +create temporary function tree_predict_v1 as 'hivemall.smile.tools.TreePredictUDFv1';
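
Note: the `opscode` model format interpreted by the deprecated `StackMachine` is just a `"; "`-separated instruction list emitted by `opCodegen()`. The sketch below is a hypothetical, hand-written script for a single numeric split (the feature index, threshold, and class labels are invented) passed directly to `tree_predict_v1` with model type `1` (uncompressed opscode), mirroring what `TreePredictUDFv1Test` does with generated scripts.

```sql
-- Hypothetical opscode script: predict class 0 when x[2] <= 3.04, otherwise class 1.
-- "ifle 5" jumps to instruction index 5 (the false branch) when the comparison fails;
-- "goto last" jumps to the trailing "call end", which pops the result off the stack.
SELECT
  tree_predict_v1(
    'model_01',                          -- model id (any string; used only for per-model caching)
    1,                                   -- ModelType.opscode, uncompressed
    'push x[2]; push 3.04; ifle 5; push 0; goto last; push 1; goto last; call end',
    array(1.0, 2.0, 2.5),                -- x[2] = 2.5 <= 3.04, so the result is 0
    true                                 -- classification
  );
```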
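A minimal usage sketch of the backward-compatible path follows, assuming a legacy model table with the pre-v0.5-rc.1 columns (`model_id`, `model_type`, `pred_model`) referenced in the documentation above; the table names are placeholders. The 1-argument form of `rf_ensemble` (handled by `RfEvaluatorV1`) majority-votes the per-tree class predictions and returns the `label`, `probability`, and `probabilities` struct produced by its terminate phase.

```sql
-- Score each row with every tree of an old model, then ensemble the votes.
SELECT
  rowid,
  rf_ensemble(predicted) AS ensembled   -- struct<label:int, probability:double, probabilities:array<double>>
FROM (
  SELECT
    t.rowid,
    tree_predict_v1(p.model_id, p.model_type, p.pred_model, t.features, true) AS predicted
  FROM
    model_v04 p            -- hypothetical table holding models built before v0.5-rc.1
    LEFT OUTER JOIN testing t
) t1
GROUP BY rowid;
```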