From c742ce58e94913bf446c3b296a24415676f9ac3b Mon Sep 17 00:00:00 2001 From: Makoto Yui Date: Thu, 8 Feb 2018 17:36:50 +0900 Subject: [PATCH] [HIVEMALL-172] Change tree_predict 3rd argument to accept string options --- .../hivemall/smile/tools/TreePredictUDF.java | 63 ++++++++++++++----- docs/gitbook/binaryclass/news20_rf.md | 5 +- docs/gitbook/binaryclass/titanic_rf.md | 10 +-- docs/gitbook/multiclass/iris_randomforest.md | 8 ++- 4 files changed, 60 insertions(+), 26 deletions(-) diff --git a/core/src/main/java/hivemall/smile/tools/TreePredictUDF.java b/core/src/main/java/hivemall/smile/tools/TreePredictUDF.java index 46b8758c0..ea3bc29e1 100644 --- a/core/src/main/java/hivemall/smile/tools/TreePredictUDF.java +++ b/core/src/main/java/hivemall/smile/tools/TreePredictUDF.java @@ -18,6 +18,7 @@ */ package hivemall.smile.tools; +import hivemall.UDFWithOptions; import hivemall.math.vector.DenseVector; import hivemall.math.vector.SparseVector; import hivemall.math.vector.Vector; @@ -37,11 +38,12 @@ import javax.annotation.Nonnull; import javax.annotation.Nullable; +import org.apache.commons.cli.CommandLine; +import org.apache.commons.cli.Options; import org.apache.hadoop.hive.ql.exec.Description; import org.apache.hadoop.hive.ql.exec.UDFArgumentException; import org.apache.hadoop.hive.ql.metadata.HiveException; import org.apache.hadoop.hive.ql.udf.UDFType; -import org.apache.hadoop.hive.ql.udf.generic.GenericUDF; import org.apache.hadoop.hive.serde2.io.DoubleWritable; import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; @@ -53,12 +55,12 @@ import org.apache.hadoop.io.IntWritable; import org.apache.hadoop.io.Text; -@Description( - name = "tree_predict", - value = "_FUNC_(string modelId, string model, array<double|string> features [, const boolean classification])" + " - Returns a prediction result of a random forest") +@Description(name = "tree_predict", + value = "_FUNC_(string modelId, 
string model, array<double|string> features [, const string options | const boolean classification=false])" + " - Returns a prediction result of a random forest" + " in <int value, array<double> posteriori> for classification and <double> for regression") @UDFType(deterministic = true, stateful = false) -public final class TreePredictUDF extends GenericUDF { +public final class TreePredictUDF extends UDFWithOptions { private boolean classification; private StringObjectInspector modelOI; @@ -71,10 +73,26 @@ public final class TreePredictUDF extends GenericUDF { @Nullable private transient Evaluator evaluator; + @Override + protected Options getOptions() { + Options opts = new Options(); + opts.addOption("c", "classification", false, + "Predict as classification [default: not enabled]"); + return opts; + } + + @Override + protected CommandLine processOptions(@Nonnull String optionValue) throws UDFArgumentException { + CommandLine cl = parseOptions(optionValue); + + this.classification = cl.hasOption("classification"); + return cl; + } + @Override public ObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException { if (argOIs.length != 3 && argOIs.length != 4) { - throw new UDFArgumentException("_FUNC_ takes 3 or 4 arguments"); + throw new UDFArgumentException("tree_predict takes 3 or 4 arguments"); } this.modelOI = HiveUtils.asStringOI(argOIs[1]); @@ -89,15 +107,25 @@ public ObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentEx this.denseInput = false; } else { throw new UDFArgumentException( - "_FUNC_ takes array<double> or array<string> for the second argument: " + "tree_predict takes array<double> or array<string> for the second argument: " + listOI.getTypeName()); } - boolean classification = false; if (argOIs.length == 4) { - classification = HiveUtils.getConstBoolean(argOIs[3]); + ObjectInspector argOI3 = argOIs[3]; + if (HiveUtils.isConstBoolean(argOI3)) { + this.classification = HiveUtils.getConstBoolean(argOI3); + } else if (HiveUtils.isConstString(argOI3)) { + String opts = 
HiveUtils.getConstString(argOI3); + processOptions(opts); + } else { + throw new UDFArgumentException( + "tree_predict expects <const boolean> or <const string> for the fourth argument: " + + argOI3.getTypeName()); + } + } else { + this.classification = false; } - this.classification = classification; if (classification) { List<String> fieldNames = new ArrayList<String>(2); @@ -105,7 +133,8 @@ public ObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentEx fieldNames.add("value"); fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector); fieldNames.add("posteriori"); - fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector)); + fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector( + PrimitiveObjectInspectorFactory.writableDoubleObjectInspector)); return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs); } else { return PrimitiveObjectInspectorFactory.writableDoubleObjectInspector; } @@ -116,7 +145,7 @@ public Object evaluate(@Nonnull DeferredObject[] arguments) throws HiveException { Object arg0 = arguments[0].get(); if (arg0 == null) { - throw new HiveException("ModelId was null"); + throw new HiveException("modelId should not be null"); } // Not using string OI for backward compatibilities String modelId = arg0.toString(); @@ -134,8 +163,8 @@ public Object evaluate(@Nonnull DeferredObject[] arguments) throws HiveException this.featuresProbe = parseFeatures(arg2, featuresProbe); if (evaluator == null) { - this.evaluator = classification ? new ClassificationEvaluator() - : new RegressionEvaluator(); + this.evaluator = + classification ? 
new ClassificationEvaluator() : new RegressionEvaluator(); } return evaluator.evaluate(modelId, model, featuresProbe); } @@ -192,8 +221,8 @@ private Vector parseFeatures(@Nonnull final Object argObj, @Nullable Vector prob } if (feature.indexOf(':') != -1) { - throw new UDFArgumentException("Invaliad feature format `<index>:<value>`: " - + col); + throw new UDFArgumentException( + "Invaliad feature format `<index>:<value>`: " + col); } final int colIndex = Integer.parseInt(feature); diff --git a/docs/gitbook/binaryclass/news20_rf.md b/docs/gitbook/binaryclass/news20_rf.md index fd0b475ac..327939bc7 100644 --- a/docs/gitbook/binaryclass/news20_rf.md +++ b/docs/gitbook/binaryclass/news20_rf.md @@ -47,7 +47,7 @@ from ## Prediction ```sql -SET hivevar:classification=true; +-- SET hivevar:classification=true; drop table rf_predicted; create table rf_predicted @@ -60,7 +60,8 @@ FROM ( SELECT rowid, m.model_weight, - tree_predict(m.model_id, m.model, t.features, ${classification}) as predicted + tree_predict(m.model_id, m.model, t.features, "-classification") as predicted + -- tree_predict(m.model_id, m.model, t.features, ${classification}) as predicted FROM rf_model m LEFT OUTER JOIN -- CROSS JOIN diff --git a/docs/gitbook/binaryclass/titanic_rf.md b/docs/gitbook/binaryclass/titanic_rf.md index 29784e06b..2b5407427 100644 --- a/docs/gitbook/binaryclass/titanic_rf.md +++ b/docs/gitbook/binaryclass/titanic_rf.md @@ -175,7 +175,7 @@ from # Prediction ```sql -SET hivevar:classification=true; +-- SET hivevar:classification=true; set hive.auto.convert.join=true; SET hive.mapjoin.optimized.hashtable=false; SET mapred.reduce.tasks=16; @@ -202,7 +202,8 @@ FROM ( -- tree_predict(p.model_id, p.model_type, p.pred_model, t.features, ${classification}) as predicted -- hivemall v0.5-rc.1 or later p.model_weight, - tree_predict(p.model_id, p.model, t.features, ${classification}) as predicted + tree_predict(p.model_id, p.model, t.features, "-classification") as predicted + -- tree_predict(p.model_id, p.model, 
t.features, ${classification}) as predicted -- tree_predict_v1(p.model_id, p.model_type, p.pred_model, t.features, ${classification}) as predicted -- to use the old model in v0.5-rc.1 or later FROM ( SELECT @@ -319,7 +320,7 @@ from > [116.12055542977338,960.8569891444097,291.08765260103837,469.74671636586226,163.721292772701,120.784769882858,847.9769298113661,554.4617571355476,346.3500941757221,97.42593940113392] 0.1838351822503962 ```sql -SET hivevar:classification=true; +-- SET hivevar:classification=true; SET hive.mapjoin.optimized.hashtable=false; SET mapred.reduce.tasks=16; @@ -345,7 +346,8 @@ FROM ( -- tree_predict(p.model_id, p.model_type, p.pred_model, t.features, ${classification}) as predicted -- hivemall v0.5-rc.1 or later p.model_weight, - tree_predict(p.model_id, p.model, t.features, ${classification}) as predicted + tree_predict(p.model_id, p.model, t.features, "-classification") as predicted + -- tree_predict(p.model_id, p.model, t.features, ${classification}) as predicted -- tree_predict_v1(p.model_id, p.model_type, p.pred_model, t.features, ${classification}) as predicted -- to use the old model in v0.5-rc.1 or later FROM ( SELECT diff --git a/docs/gitbook/multiclass/iris_randomforest.md b/docs/gitbook/multiclass/iris_randomforest.md index b42129743..bfc197f09 100644 --- a/docs/gitbook/multiclass/iris_randomforest.md +++ b/docs/gitbook/multiclass/iris_randomforest.md @@ -206,7 +206,7 @@ from # Prediction ```sql -set hivevar:classification=true; +-- set hivevar:classification=true; set hive.auto.convert.join=true; set hive.mapjoin.optimized.hashtable=false; @@ -225,7 +225,8 @@ FROM ( -- tree_predict(p.model_id, p.model_type, p.pred_model, t.features, ${classification}) as predicted -- hivemall v0.5-rc.1 or later p.model_weight, - tree_predict(p.model_id, p.model, t.features, ${classification}) as predicted + tree_predict(p.model_id, p.model, t.features, "-classification") as predicted + -- tree_predict(p.model_id, p.model, t.features, 
${classification}) as predicted -- tree_predict_v1(p.model_id, p.model_type, p.pred_model, t.features, ${classification}) as predicted -- to use the old model in v0.5-rc.1 or later FROM model p @@ -265,7 +266,8 @@ FROM ( -- tree_predict(p.model_id, p.model_type, p.pred_model, t.features, ${classification}) as predicted -- hivemall v0.5-rc.1 or later p.model_weight, - tree_predict(p.model_id, p.model, t.features, ${classification}) as predicted + tree_predict(p.model_id, p.model, t.features, "-classification") as predicted + -- tree_predict(p.model_id, p.model, t.features, ${classification}) as predicted -- tree_predict_v1(p.model_id, p.model_type, p.pred_model, t.features, ${classification}) as predicted as predicted -- to use the old model in v0.5-rc.1 or later FROM ( SELECT