Skip to content
This repository has been archived by the owner on Sep 20, 2022. It is now read-only.

Commit

Permalink
[HIVEMALL-172] Change tree_predict 3rd argument to accept string options
Browse files Browse the repository at this point in the history
  • Loading branch information
myui committed Feb 8, 2018
1 parent 2958af0 commit c742ce5
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 26 deletions.
63 changes: 46 additions & 17 deletions core/src/main/java/hivemall/smile/tools/TreePredictUDF.java
Expand Up @@ -18,6 +18,7 @@
*/
package hivemall.smile.tools;

import hivemall.UDFWithOptions;
import hivemall.math.vector.DenseVector;
import hivemall.math.vector.SparseVector;
import hivemall.math.vector.Vector;
Expand All @@ -37,11 +38,12 @@
import javax.annotation.Nonnull;
import javax.annotation.Nullable;

import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.UDFType;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
Expand All @@ -53,12 +55,12 @@
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;

@Description(
name = "tree_predict",
value = "_FUNC_(string modelId, string model, array<double|string> features [, const boolean classification])"
+ " - Returns a prediction result of a random forest")
@Description(name = "tree_predict",
value = "_FUNC_(string modelId, string model, array<double|string> features [, const string options | const boolean classification=false])"
+ " - Returns a prediction result of a random forest"
+ " in <int value, array<double> posteriori> for classification and <double> for regression")
@UDFType(deterministic = true, stateful = false)
public final class TreePredictUDF extends GenericUDF {
public final class TreePredictUDF extends UDFWithOptions {

private boolean classification;
private StringObjectInspector modelOI;
Expand All @@ -71,10 +73,26 @@ public final class TreePredictUDF extends GenericUDF {
@Nullable
private transient Evaluator evaluator;

@Override
protected Options getOptions() {
    // Declares the UDF options accepted in the 4th (const string) argument.
    // -c / -classification switches the prediction from regression to classification.
    final Options options = new Options();
    options.addOption("c", "classification", false,
        "Predict as classification [default: not enabled]");
    return options;
}

@Override
protected CommandLine processOptions(@Nonnull String optionValue) throws UDFArgumentException {
    // Parse the raw option string (e.g. "-classification") and remember whether
    // classification mode was requested; defaults to false when absent.
    final CommandLine commandLine = parseOptions(optionValue);
    this.classification = commandLine.hasOption("classification");
    return commandLine;
}

@Override
public ObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
if (argOIs.length != 3 && argOIs.length != 4) {
throw new UDFArgumentException("_FUNC_ takes 3 or 4 arguments");
throw new UDFArgumentException("tree_predict takes 3 or 4 arguments");
}

this.modelOI = HiveUtils.asStringOI(argOIs[1]);
Expand All @@ -89,23 +107,34 @@ public ObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentEx
this.denseInput = false;
} else {
throw new UDFArgumentException(
"_FUNC_ takes array<double> or array<string> for the second argument: "
"tree_predict takes array<double> or array<string> for the second argument: "
+ listOI.getTypeName());
}

boolean classification = false;
if (argOIs.length == 4) {
classification = HiveUtils.getConstBoolean(argOIs[3]);
ObjectInspector argOI3 = argOIs[3];
if (HiveUtils.isConstBoolean(argOI3)) {
this.classification = HiveUtils.getConstBoolean(argOI3);
} else if (HiveUtils.isConstString(argOI3)) {
String opts = HiveUtils.getConstString(argOI3);
processOptions(opts);
} else {
throw new UDFArgumentException(
"tree_predict expects <const boolean> or <const string> for the fourth argument: "
+ argOI3.getTypeName());
}
} else {
this.classification = false;
}
this.classification = classification;

if (classification) {
List<String> fieldNames = new ArrayList<String>(2);
List<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>(2);
fieldNames.add("value");
fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
fieldNames.add("posteriori");
fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector));
fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(
PrimitiveObjectInspectorFactory.writableDoubleObjectInspector));
return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
} else {
return PrimitiveObjectInspectorFactory.writableDoubleObjectInspector;
Expand All @@ -116,7 +145,7 @@ public ObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentEx
public Object evaluate(@Nonnull DeferredObject[] arguments) throws HiveException {
Object arg0 = arguments[0].get();
if (arg0 == null) {
throw new HiveException("ModelId was null");
throw new HiveException("modelId should not be null");
}
// Not using string OI for backward compatibilities
String modelId = arg0.toString();
Expand All @@ -134,8 +163,8 @@ public Object evaluate(@Nonnull DeferredObject[] arguments) throws HiveException
this.featuresProbe = parseFeatures(arg2, featuresProbe);

if (evaluator == null) {
this.evaluator = classification ? new ClassificationEvaluator()
: new RegressionEvaluator();
this.evaluator =
classification ? new ClassificationEvaluator() : new RegressionEvaluator();
}
return evaluator.evaluate(modelId, model, featuresProbe);
}
Expand Down Expand Up @@ -192,8 +221,8 @@ private Vector parseFeatures(@Nonnull final Object argObj, @Nullable Vector prob
}

if (feature.indexOf(':') != -1) {
throw new UDFArgumentException("Invalid feature format `<index>:<value>`: "
+ col);
throw new UDFArgumentException(
"Invalid feature format `<index>:<value>`: " + col);
}

final int colIndex = Integer.parseInt(feature);
Expand Down
5 changes: 3 additions & 2 deletions docs/gitbook/binaryclass/news20_rf.md
Expand Up @@ -47,7 +47,7 @@ from
## Prediction

```sql
SET hivevar:classification=true;
-- SET hivevar:classification=true;

drop table rf_predicted;
create table rf_predicted
Expand All @@ -60,7 +60,8 @@ FROM (
SELECT
rowid,
m.model_weight,
tree_predict(m.model_id, m.model, t.features, ${classification}) as predicted
tree_predict(m.model_id, m.model, t.features, "-classification") as predicted
-- tree_predict(m.model_id, m.model, t.features, ${classification}) as predicted
FROM
rf_model m
LEFT OUTER JOIN -- CROSS JOIN
Expand Down
10 changes: 6 additions & 4 deletions docs/gitbook/binaryclass/titanic_rf.md
Expand Up @@ -175,7 +175,7 @@ from
# Prediction

```sql
SET hivevar:classification=true;
-- SET hivevar:classification=true;
set hive.auto.convert.join=true;
SET hive.mapjoin.optimized.hashtable=false;
SET mapred.reduce.tasks=16;
Expand All @@ -202,7 +202,8 @@ FROM (
-- tree_predict(p.model_id, p.model_type, p.pred_model, t.features, ${classification}) as predicted
-- hivemall v0.5-rc.1 or later
p.model_weight,
tree_predict(p.model_id, p.model, t.features, ${classification}) as predicted
tree_predict(p.model_id, p.model, t.features, "-classification") as predicted
-- tree_predict(p.model_id, p.model, t.features, ${classification}) as predicted
-- tree_predict_v1(p.model_id, p.model_type, p.pred_model, t.features, ${classification}) as predicted -- to use the old model in v0.5-rc.1 or later
FROM (
SELECT
Expand Down Expand Up @@ -319,7 +320,7 @@ from
> [116.12055542977338,960.8569891444097,291.08765260103837,469.74671636586226,163.721292772701,120.784769882858,847.9769298113661,554.4617571355476,346.3500941757221,97.42593940113392] 0.1838351822503962
```sql
SET hivevar:classification=true;
-- SET hivevar:classification=true;
SET hive.mapjoin.optimized.hashtable=false;
SET mapred.reduce.tasks=16;

Expand All @@ -345,7 +346,8 @@ FROM (
-- tree_predict(p.model_id, p.model_type, p.pred_model, t.features, ${classification}) as predicted
-- hivemall v0.5-rc.1 or later
p.model_weight,
tree_predict(p.model_id, p.model, t.features, ${classification}) as predicted
tree_predict(p.model_id, p.model, t.features, "-classification") as predicted
-- tree_predict(p.model_id, p.model, t.features, ${classification}) as predicted
-- tree_predict_v1(p.model_id, p.model_type, p.pred_model, t.features, ${classification}) as predicted -- to use the old model in v0.5-rc.1 or later
FROM (
SELECT
Expand Down
8 changes: 5 additions & 3 deletions docs/gitbook/multiclass/iris_randomforest.md
Expand Up @@ -206,7 +206,7 @@ from
# Prediction

```sql
set hivevar:classification=true;
-- set hivevar:classification=true;
set hive.auto.convert.join=true;
set hive.mapjoin.optimized.hashtable=false;

Expand All @@ -225,7 +225,8 @@ FROM (
-- tree_predict(p.model_id, p.model_type, p.pred_model, t.features, ${classification}) as predicted
-- hivemall v0.5-rc.1 or later
p.model_weight,
tree_predict(p.model_id, p.model, t.features, ${classification}) as predicted
tree_predict(p.model_id, p.model, t.features, "-classification") as predicted
-- tree_predict(p.model_id, p.model, t.features, ${classification}) as predicted
-- tree_predict_v1(p.model_id, p.model_type, p.pred_model, t.features, ${classification}) as predicted -- to use the old model in v0.5-rc.1 or later
FROM
model p
Expand Down Expand Up @@ -265,7 +266,8 @@ FROM (
-- tree_predict(p.model_id, p.model_type, p.pred_model, t.features, ${classification}) as predicted
-- hivemall v0.5-rc.1 or later
p.model_weight,
tree_predict(p.model_id, p.model, t.features, ${classification}) as predicted
tree_predict(p.model_id, p.model, t.features, "-classification") as predicted
-- tree_predict(p.model_id, p.model, t.features, ${classification}) as predicted
-- tree_predict_v1(p.model_id, p.model_type, p.pred_model, t.features, ${classification}) as predicted as predicted -- to use the old model in v0.5-rc.1 or later
FROM (
SELECT
Expand Down

0 comments on commit c742ce5

Please sign in to comment.