Skip to content
Permalink
Browse files
[HIVEMALL-273] Support xgboost v0.90
## What changes were proposed in this pull request?

Support xgboost v0.90

## What type of PR is it?

Improvement

## What is the Jira issue?

https://issues.apache.org/jira/browse/HIVEMALL-273

## How was this patch tested?

unit tests and manual tests on EMR

## How to use this feature?

https://gist.github.com/myui/aa6e142a95ca8f995cc8e49146dbe2eb

## Checklist

- [x] Did you apply source code formatter, i.e., `./bin/format_code.sh`, for your commit?
- [x] Did you run system tests on Hive (or Spark)?

Author: Makoto Yui <myui@apache.org>

Closes #209 from myui/HIVEMALL-273.
  • Loading branch information
myui committed Oct 30, 2019
1 parent 695ad30 commit 97cbc0e6fbfaf706722bea6bec434a8a9b6321bd
Showing 8 changed files with 112 additions and 6 deletions.
@@ -157,8 +157,8 @@ public List<IntWritable> evaluate(DeferredObject[] arguments) throws HiveExcepti
* @param stop exclusive index of the end
* @param step positive interval value
*/
private static IntWritable[] range(final int start, final int stop,
@Nonnegative final int step) throws UDFArgumentException {
private static IntWritable[] range(final int start, final int stop, @Nonnegative final int step)
throws UDFArgumentException {
if (step <= 0) {
throw new UDFArgumentException("Invalid step value: " + step);
}
@@ -1152,8 +1152,7 @@ public static int[] range(final int start, final int stop, @Nonnegative final in
return r;
}

public static int divideAndRoundUp(@Nonnegative final int num,
@Nonnegative final int divisor) {
public static int divideAndRoundUp(@Nonnegative final int num, @Nonnegative final int divisor) {
return (num + divisor - 1) / divisor;
}

@@ -33,3 +33,6 @@ create temporary function xgboost_predict as 'hivemall.xgboost.tools.XGBoostPred

drop temporary function if exists xgboost_multiclass_predict;
create temporary function xgboost_multiclass_predict as 'hivemall.xgboost.tools.XGBoostMulticlassPredictUDTF';

drop temporary function if exists xgboost_version;
create temporary function xgboost_version as 'hivemall.xgboost.XGBoostVersionUDF';
@@ -902,3 +902,5 @@ CREATE FUNCTION xgboost_predict AS 'hivemall.xgboost.tools.XGBoostPredictUDTF' U
DROP FUNCTION xgboost_multiclass_predict;
CREATE FUNCTION xgboost_multiclass_predict AS 'hivemall.xgboost.tools.XGBoostMulticlassPredictUDTF' USING JAR '${hivemall_jar}';

DROP FUNCTION IF EXISTS xgboost_version;
CREATE FUNCTION xgboost_version as 'hivemall.xgboost.XGBoostVersionUDF' USING JAR '${hivemall_jar}';
@@ -32,7 +32,7 @@

<properties>
<main.basedir>${project.parent.basedir}</main.basedir>
<xgboost.version>0.7-rc5</xgboost.version>
<xgboost.version>0.90-rc1</xgboost.version>
</properties>

<dependencies>
@@ -18,10 +18,16 @@
*/
package hivemall.xgboost;

import ml.dmlc.xgboost4j.LabeledPoint;

import java.io.IOException;
import java.io.InputStream;
import java.util.Properties;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;

import ml.dmlc.xgboost4j.LabeledPoint;
import org.apache.hadoop.hive.ql.metadata.HiveException;

public final class XGBoostUtils {

@@ -54,4 +60,17 @@ public static LabeledPoint parseFeatures(final double target,
return new LabeledPoint((float) target, indices, values);
}

@Nonnull
public static String getVersion() throws HiveException {
Properties props = new Properties();
try (InputStream versionResourceFile =
Thread.currentThread().getContextClassLoader().getResourceAsStream(
"xgboost4j-version.properties")) {
props.load(versionResourceFile);
} catch (IOException e) {
throw new HiveException("Failed to load xgboost4j-version.properties", e);
}
return props.getProperty("version", "<unknown>");
}

}
@@ -0,0 +1,45 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package hivemall.xgboost;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;

import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDF;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.UDFType;

@Description(name = "xgboost_version", value = "_FUNC_() - Returns the version of xgboost",
extended = "SELECT xgboost_version();")
@UDFType(deterministic = true, stateful = false)
public final class XGBoostVersionUDF extends UDF {

@Nullable
private String version;

@Nonnull
public String evaluate() throws HiveException {
if (version == null) {
this.version = XGBoostUtils.getVersion();
}
return version;
}

}
@@ -0,0 +1,38 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package hivemall.xgboost;


import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.junit.Assert;
import org.junit.Test;

public class XGBoostVersionUDFTest {

@Test
public void test() throws HiveException {
XGBoostVersionUDF udf = new XGBoostVersionUDF();
String v1 = udf.evaluate();
Assert.assertNotNull(v1);
Assert.assertNotEquals("<unknown>", v1);
String v2 = udf.evaluate();
Assert.assertSame(v1, v2);
}

}

0 comments on commit 97cbc0e

Please sign in to comment.