Skip to content
Permalink
Browse files
[HIVEMALL-276] Stable support for XGBoost v0.90
## What changes were proposed in this pull request?

- Fix xgboost module to create DMatrix from CSRMatrix
- Support xgboost v0.90 hyperparameters
- Replace xgboost4j with [xgboost-predictor](https://github.com/komiya-atsushi/xgboost-predictor-java) for prediction
- Add documentation about XGBoost

## What type of PR is it?

Refactoring, Improvement

## What is the Jira issue?

https://issues.apache.org/jira/browse/HIVEMALL-276
https://issues.apache.org/jira/browse/HIVEMALL-275
https://issues.apache.org/jira/browse/HIVEMALL-279
https://issues.apache.org/jira/browse/HIVEMALL-272
https://issues.apache.org/jira/browse/HIVEMALL-27

## How to use this feature?

As described in the [user guide](http://hivemall.apache.org/userguide/index.html).

## How was this patch tested?

Unit tests and manual tests on EMR.

## Checklist

- [x] Did you apply source code formatter, i.e., `./bin/format_code.sh`, for your commit?
- [x] Did you run system tests on Hive (or Spark)?

Author: Makoto Yui <myui@apache.org>

Closes #213 from myui/HIVEMALL-275-2.
  • Loading branch information
myui committed Nov 22, 2019
1 parent 86409b4 commit 54e1d7da67690b17640809e60c7b251fdc82308a
Showing 48 changed files with 4,098 additions and 1,123 deletions.
@@ -28,6 +28,8 @@
import java.util.Collections;
import java.util.List;

import javax.annotation.CheckForNull;
import javax.annotation.Nonnegative;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;

@@ -171,6 +173,31 @@ protected final List<FeatureValue> parseFeatures(@Nonnull final List<?> features
return list;
}

/**
 * Returns the given argument unchanged after verifying that it is not null.
 *
 * @param arg the value to check
 * @param errMsg message of the exception raised when {@code arg} is null
 * @return {@code arg} itself
 * @throws UDFArgumentException if {@code arg} is null
 */
protected static <T> T checkNotNull(@CheckForNull final T arg, @Nonnull final String errMsg)
        throws UDFArgumentException {
    if (arg == null) {
        throw new UDFArgumentException(errMsg);
    }
    return arg;
}

/**
 * Returns the given argument unchanged after verifying that it is not null, reporting the
 * argument position in the error message otherwise.
 *
 * @param arg the value to check
 * @param index position of the argument, used only to build the error message
 * @return {@code arg} itself
 * @throws UDFArgumentException if {@code arg} is null
 */
protected static <T> T checkNotNull(@CheckForNull final T arg, @Nonnegative final int index)
        throws UDFArgumentException {
    if (arg != null) {
        return arg;
    }
    final String errMsg = String.format("%d-th argument MUST not be null", index);
    throw new UDFArgumentException(errMsg);
}

/**
 * Fetches {@code args[index]} and verifies that it is not null.
 *
 * @param args the argument array to read from
 * @param index position of the argument to fetch
 * @return the non-null element at {@code index}
 * @throws UDFArgumentException if the element at {@code index} is null
 */
protected static Object nonNullArgument(@Nonnull final Object[] args,
        @Nonnegative final int index) throws UDFArgumentException {
    final Object value = args[index];
    if (value != null) {
        return value;
    }
    throw new UDFArgumentException(String.format("%d-th argument MUST not be null", index));
}

/**
* Raise {@link UDFArgumentException} if the given condition is false.
*
@@ -60,7 +60,7 @@ protected void checkLossFunction(@Nonnull LossFunction lossFunction)
/**
 * Verifies that the label is one of -1, 0, or 1.
 *
 * @param label the target value to validate
 * @throws UDFArgumentException if {@code label} is none of -1, 0, or 1; the offending value is
 *         included in the message
 */
@Override
protected void checkTargetValue(final float label) throws UDFArgumentException {
    if (label != -1 && label != 0 && label != 1) {
        // The label is concatenated outside of the string literal so that the actual value is reported
        throw new UDFArgumentException("Invalid label value for classification: " + label);
    }
}

@@ -40,6 +40,7 @@
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AbstractAggregationBuffer;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationType;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.serde2.lazybinary.LazyBinaryArray;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
@@ -55,7 +56,6 @@
import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableIntObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableLongObjectInspector;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.io.FloatWritable;
import org.apache.hadoop.io.IntWritable;

// @formatter:off
@@ -130,7 +130,7 @@ public ObjectInspector init(Mode mode, ObjectInspector[] parameters) throws Hive
outputOI = internalMergeOI();
} else {// terminate
outputOI = ObjectInspectorFactory.getStandardListObjectInspector(
PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
}
return outputOI;
}
@@ -222,7 +222,7 @@ public void merge(@SuppressWarnings("deprecation") AggregationBuffer aggr, Objec
}

@Override
public List<FloatWritable> terminate(
public List<DoubleWritable> terminate(
@SuppressWarnings("deprecation") AggregationBuffer aggr) throws HiveException {
ArrayAvgAggregationBuffer myAggr = (ArrayAvgAggregationBuffer) aggr;

@@ -234,11 +234,11 @@ public List<FloatWritable> terminate(
final double[] sum = myAggr._sum;
final long[] count = myAggr._count;

final FloatWritable[] ary = new FloatWritable[size];
final DoubleWritable[] ary = new DoubleWritable[size];
for (int i = 0; i < size; i++) {
long c = count[i];
float avg = (c == 0) ? 0.f : (float) (sum[i] / c);
ary[i] = new FloatWritable(avg);
ary[i] = new DoubleWritable(avg);
}
return Arrays.asList(ary);
}
@@ -115,6 +115,17 @@ public static int parseInt(@Nonnull final Object o) {
return Integer.parseInt(s);
}

/**
 * Copies the given object with the inspector's {@code copyObject} and returns the primitive
 * writable obtained from that copy.
 *
 * @param o the object to copy; may be null
 * @param poi the inspector used to copy {@code o} and to extract the writable
 * @return the writable of the copied object, or null if {@code o} is null
 */
@Nullable
public static Writable copyToWritable(@Nullable final Object o,
        @Nonnull final PrimitiveObjectInspector poi) {
    if (o == null) {
        return null;
    }
    return (Writable) poi.getPrimitiveWritableObject(poi.copyObject(o));
}

public static Text asText(@Nullable final Object o) {
if (o == null) {
return null;
@@ -562,6 +573,7 @@ public static double[] getConstDoubleArray(@Nonnull final ObjectInspector oi)
return ary;
}

@Nullable
public static String getConstString(@Nonnull final ObjectInspector oi)
throws UDFArgumentException {
if (!isStringOI(oi)) {
@@ -572,6 +584,19 @@ public static String getConstString(@Nonnull final ObjectInspector oi)
return v == null ? null : v.toString();
}


/**
 * Reads the constant string value carried by the inspector at {@code argIndex}.
 *
 * @param argOIs the argument inspectors
 * @param argIndex position of the inspector to read
 * @return the constant value as a string, or null if the constant value is null
 * @throws UDFArgumentException if the inspector at {@code argIndex} is not a string inspector
 *         (also propagated from the helpers that resolve the inspector and its constant)
 */
@Nullable
public static String getConstString(@Nonnull final ObjectInspector[] argOIs, final int argIndex)
        throws UDFArgumentException {
    final ObjectInspector oi = getObjectInspector(argOIs, argIndex);
    if (isStringOI(oi)) {
        final Text constant = getConstValue(oi);
        return (constant != null) ? constant.toString() : null;
    }
    throw new UDFArgumentException("argOIs[" + argIndex + "] must be a Text value: "
            + TypeInfoUtils.getTypeInfoFromObjectInspector(oi));
}

public static boolean getConstBoolean(@Nonnull final ObjectInspector oi)
throws UDFArgumentException {
if (!isBooleanOI(oi)) {
@@ -882,6 +907,14 @@ public static ConstantObjectInspector asConstantObjectInspector(
return (ConstantObjectInspector) oi;
}

/**
 * Returns the inspector at the given position after validating the position.
 *
 * @param argOIs the argument inspectors
 * @param argIndex position of the inspector to return
 * @return {@code argOIs[argIndex]}
 * @throws UDFArgumentException if {@code argIndex} is negative or not less than
 *         {@code argOIs.length}
 */
public static ObjectInspector getObjectInspector(@Nonnull final ObjectInspector[] argOIs,
        final int argIndex) throws UDFArgumentException {
    // Reject negative indexes too, so that callers always get UDFArgumentException
    // rather than an ArrayIndexOutOfBoundsException
    if (argIndex < 0 || argIndex >= argOIs.length) {
        throw new UDFArgumentException("Illegal argument index:" + argIndex);
    }
    return argOIs[argIndex];
}

@Nonnull
public static PrimitiveObjectInspector asPrimitiveObjectInspector(
@Nonnull final ObjectInspector oi) throws UDFArgumentException {
@@ -892,15 +925,38 @@ public static PrimitiveObjectInspector asPrimitiveObjectInspector(
return (PrimitiveObjectInspector) oi;
}

/**
 * Returns the inspector at {@code argIndex} as a {@link PrimitiveObjectInspector}.
 *
 * @param argOIs the argument inspectors
 * @param argIndex position of the inspector to cast
 * @return the inspector at {@code argIndex}, cast to {@link PrimitiveObjectInspector}
 * @throws UDFArgumentException if the inspector's category is not {@code PRIMITIVE}
 */
@Nonnull
public static PrimitiveObjectInspector asPrimitiveObjectInspector(
        @Nonnull final ObjectInspector[] argOIs, final int argIndex)
        throws UDFArgumentException {
    final ObjectInspector oi = getObjectInspector(argOIs, argIndex);
    final boolean primitive = (oi.getCategory() == Category.PRIMITIVE);
    if (primitive) {
        return (PrimitiveObjectInspector) oi;
    }
    throw new UDFArgumentException("Expecting PrimitiveObjectInspector for argOIs["
            + argIndex + "] but got " + TypeInfoUtils.getTypeInfoFromObjectInspector(oi));
}

@Nonnull
public static StringObjectInspector asStringOI(@Nonnull final ObjectInspector argOI)
throws UDFArgumentException {
if (!STRING_TYPE_NAME.equals(argOI.getTypeName())) {
if (!isStringOI(argOI)) {
throw new UDFArgumentException("Argument type must be String: " + argOI.getTypeName());
}
return (StringObjectInspector) argOI;
}

/**
 * Returns the inspector at {@code argIndex} as a {@link StringObjectInspector}.
 *
 * @param argOIs the argument inspectors
 * @param argIndex position of the inspector to cast
 * @return the inspector at {@code argIndex}, cast to {@link StringObjectInspector}
 * @throws UDFArgumentException if {@code isStringOI} rejects the inspector
 */
@Nonnull
public static StringObjectInspector asStringOI(@Nonnull final ObjectInspector[] argOIs,
        final int argIndex) throws UDFArgumentException {
    final ObjectInspector candidate = getObjectInspector(argOIs, argIndex);
    if (isStringOI(candidate)) {
        return (StringObjectInspector) candidate;
    }
    final String typeName = candidate.getTypeName();
    throw new UDFArgumentException("argOIs[" + argIndex + "] type must be String: " + typeName);
}

@Nonnull
public static BinaryObjectInspector asBinaryOI(@Nonnull final ObjectInspector argOI)
throws UDFArgumentException {
@@ -1047,6 +1103,31 @@ public static PrimitiveObjectInspector asDoubleCompatibleOI(
return oi;
}

/**
 * Returns the primitive inspector at {@code argIndex} when its primitive category is one of
 * BYTE, SHORT, INT, LONG, FLOAT, DOUBLE, DECIMAL, STRING, or TIMESTAMP.
 *
 * @param argOIs the argument inspectors
 * @param argIndex position of the inspector to check
 * @return the primitive inspector at {@code argIndex}
 * @throws UDFArgumentException if the inspector has any other primitive category (raised as
 *         {@link UDFArgumentTypeException}), or if it cannot be resolved as a primitive
 *         inspector
 */
@Nonnull
public static PrimitiveObjectInspector asDoubleCompatibleOI(
        @Nonnull final ObjectInspector[] argOIs, final int argIndex)
        throws UDFArgumentException {
    final PrimitiveObjectInspector oi = asPrimitiveObjectInspector(argOIs, argIndex);
    switch (oi.getPrimitiveCategory()) {
        case BYTE:
        case SHORT:
        case INT:
        case LONG:
        case FLOAT:
        case DOUBLE:
        case DECIMAL:
        case STRING:
        case TIMESTAMP:
            // accepted categories
            return oi;
        default:
            throw new UDFArgumentTypeException(argIndex,
                "Only numeric or string type arguments are accepted but " + oi.getTypeName()
                        + " is passed for argument index " + argIndex);
    }
}

@Nonnull
public static PrimitiveObjectInspector asFloatingPointOI(@Nonnull final ObjectInspector argOI)
throws UDFArgumentTypeException {
@@ -1101,6 +1182,19 @@ public static ListObjectInspector asListOI(@Nonnull final ObjectInspector oi)
return (ListObjectInspector) oi;
}

/**
 * Returns the inspector at {@code argIndex} as a {@link ListObjectInspector}.
 *
 * @param argOIs the argument inspectors
 * @param argIndex position of the inspector to cast
 * @return the inspector at {@code argIndex}, cast to {@link ListObjectInspector}
 * @throws UDFArgumentException if the inspector's category is not {@code LIST}
 */
@Nonnull
public static ListObjectInspector asListOI(@Nonnull final ObjectInspector[] argOIs,
        final int argIndex) throws UDFArgumentException {
    final ObjectInspector oi = getObjectInspector(argOIs, argIndex);
    if (oi.getCategory() == Category.LIST) {
        return (ListObjectInspector) oi;
    }
    throw new UDFArgumentException("Expecting ListObjectInspector for argOIs[" + argIndex
            + "] but got " + TypeInfoUtils.getTypeInfoFromObjectInspector(oi));
}


@Nonnull
public static MapObjectInspector asMapOI(@Nonnull final ObjectInspector oi)
throws UDFArgumentException {
@@ -90,6 +90,30 @@ public static List<DoubleWritable> newDoubleList(final int size) {
return newDoubleList(size, 0.d);
}

/**
 * Creates a list of {@code size} elements, using {@code 0.f} as the default value.
 *
 * @param size number of elements
 * @return the list produced by {@link #newFloatList(int, float)}
 */
@Nonnull
public static List<FloatWritable> newFloatList(final int size) {
    final float zero = 0.f;
    return newFloatList(size, zero);
}

/**
 * Creates a list holding {@code size} separate {@code FloatWritable} instances, each
 * initialized with {@code defaultValue}.
 *
 * @param size number of elements
 * @param defaultValue value assigned to every element
 * @return a newly allocated {@link ArrayList}
 */
@Nonnull
public static List<FloatWritable> newFloatList(final int size, final float defaultValue) {
    // An ArrayList is filled here instead of wrapping an array with Arrays.asList(...)
    // as a workaround to avoid a bug in Kryo
    // https://issues.apache.org/jira/browse/HIVE-12551
    final List<FloatWritable> result = new ArrayList<FloatWritable>(size);
    for (int remaining = size; remaining > 0; remaining--) {
        result.add(new FloatWritable(defaultValue));
    }
    return result;
}


@Nonnull
public static List<DoubleWritable> newDoubleList(final int size, final double defaultValue) {
// workaround to avoid a bug in Kryo
@@ -171,6 +195,20 @@ public static List<DoubleWritable> toWritableList(@Nonnull final double[] src) {
return list;
}

/**
 * Writes the values of {@code src} into {@code list} as new {@code FloatWritable} elements,
 * or delegates to {@code toWritableList(src)} when no list is given.
 *
 * @param src the values to store
 * @param list the list to overwrite in place; may be null
 * @return {@code list} with every element replaced, or the result of
 *         {@code toWritableList(src)} when {@code list} is null
 * @throws UDFArgumentException declared for the {@code Preconditions.checkArgument} size
 *         check, which requires {@code src.length == list.size()}
 */
@Nonnull
public static List<FloatWritable> toWritableList(@Nonnull final float[] src,
        @Nullable List<FloatWritable> list) throws UDFArgumentException {
    if (list != null) {
        Preconditions.checkArgument(src.length == list.size(), UDFArgumentException.class);
        int pos = 0;
        for (final float value : src) {
            list.set(pos++, new FloatWritable(value));
        }
        return list;
    }
    return toWritableList(src);
}

@Nonnull
public static List<DoubleWritable> toWritableList(@Nonnull final double[] src,
@Nullable List<DoubleWritable> list) throws UDFArgumentException {
@@ -185,6 +223,15 @@ public static List<DoubleWritable> toWritableList(@Nonnull final double[] src,
return list;
}

/**
 * Replaces every element of {@code list} with a new {@code FloatWritable} holding the value at
 * the same position in {@code src}.
 *
 * @param src the values to store
 * @param list the list to overwrite in place
 * @throws UDFArgumentException declared for the {@code Preconditions.checkArgument} size
 *         check, which requires {@code src.length == list.size()}
 */
public static void setValues(@Nonnull final float[] src,
        @Nonnull final List<FloatWritable> list) throws UDFArgumentException {
    Preconditions.checkArgument(src.length == list.size(), UDFArgumentException.class);
    for (int i = 0; i < src.length; i++) {
        list.set(i, new FloatWritable(src[i]));
    }
}

/**
 * Wraps the given string in a new {@code Text}.
 *
 * @param v the string to wrap
 * @return a new {@code Text} constructed from {@code v}
 */
public static Text val(final String v) {
    final Text text = new Text(v);
    return text;
}
@@ -19,6 +19,7 @@
package hivemall.utils.io;

import hivemall.utils.codec.ZigZagLEB128Codec;
import hivemall.utils.io.CompressionStreamFactory.CompressionAlgorithm;

import java.io.BufferedReader;
import java.io.Closeable;
@@ -351,4 +352,49 @@ public static void readFully(final InputStream in, final byte[] b) throws IOExce
readFully(in, b, 0, b.length);
}

/**
 * Converts the whole byte array with {@link #toCompressedText(byte[], int)}.
 *
 * @param in the bytes to convert
 * @return the converted bytes
 * @throws IOException if the conversion fails
 */
@Nonnull
public static byte[] toCompressedText(@Nonnull final byte[] in) throws IOException {
    final int length = in.length;
    return toCompressedText(in, length);
}

/**
 * Copies the input bytes through a deflate compression stream that is layered on top of a
 * {@code Base91OutputStream}, and returns the bytes collected by the underlying buffer.
 *
 * @param in the bytes to convert
 * @param len length passed to the {@code FastByteArrayInputStream} wrapping {@code in}
 * @return the bytes written to the in-memory output buffer
 * @throws IOException if creating the stream, copying, or finishing fails
 */
@Nonnull
public static byte[] toCompressedText(@Nonnull final byte[] in, final int len)
        throws IOException {
    final FastByteArrayInputStream source = new FastByteArrayInputStream(in, len);
    final FastMultiByteArrayOutputStream sink = new FastMultiByteArrayOutputStream();

    FinishableOutputStream deflater = null;
    try {
        deflater = CompressionStreamFactory.createOutputStream(new Base91OutputStream(sink),
            CompressionAlgorithm.deflate);
        copy(source, deflater);
        // finish before reading the buffer; flush is called
        deflater.finish();
        return sink.toByteArray_clear();
    } finally {
        IOUtils.closeQuietly(deflater);
    }
}

/**
 * Converts the whole byte array with {@link #fromCompressedText(byte[], int)}.
 *
 * @param src the bytes to convert
 * @return the converted bytes
 * @throws IOException if the conversion fails
 */
@Nonnull
public static byte[] fromCompressedText(@Nonnull final byte[] src) throws IOException {
    final int length = src.length;
    return fromCompressedText(src, length);
}

/**
 * Reads the input bytes through a {@code Base91InputStream} wrapped by a deflate input
 * stream, and returns everything that can be read from it.
 *
 * @param src the bytes to convert
 * @param len length passed to the {@code FastByteArrayInputStream} wrapping {@code src}
 * @return the bytes copied into the in-memory output buffer
 * @throws IOException if creating the stream or copying fails
 */
@Nonnull
public static byte[] fromCompressedText(@Nonnull final byte[] src, final int len)
        throws IOException {
    final FastByteArrayInputStream source = new FastByteArrayInputStream(src, len);
    final FastMultiByteArrayOutputStream sink = new FastMultiByteArrayOutputStream();

    InputStream inflater = null;
    try {
        final Base91InputStream decoded = new Base91InputStream(source);
        inflater = CompressionStreamFactory.createInputStream(decoded,
            CompressionAlgorithm.deflate);
        copy(inflater, sink);
        return sink.toByteArray_clear();
    } finally {
        IOUtils.closeQuietly(inflater);
    }
}

}

0 comments on commit 54e1d7d

Please sign in to comment.