From 04c3ce8ed8d8a9c3380c5bad9577dc2e9a8e378a Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Fri, 16 Mar 2018 15:24:49 +0900 Subject: [PATCH 1/3] Support spark-v2.3 --- spark/pom.xml | 1 + spark/spark-2.3/bin/mvn-zinc | 99 + spark/spark-2.3/extra-src/README.md | 20 + .../org/apache/spark/sql/hive/HiveShim.scala | 279 ++ spark/spark-2.3/pom.xml | 187 ++ .../hivemall/xgboost/XGBoostOptions.scala | 59 + ...pache.spark.sql.sources.DataSourceRegister | 1 + .../src/main/resources/log4j.properties | 29 + .../hivemall/tools/RegressionDatagen.scala | 67 + .../sql/catalyst/expressions/EachTopK.scala | 133 + .../sql/catalyst/plans/logical/JoinTopK.scala | 68 + .../utils/InternalRowPriorityQueue.scala | 76 + .../sql/execution/UserProvidedPlanner.scala | 83 + .../datasources/csv/csvExpressions.scala | 169 ++ .../joins/ShuffledHashJoinTopKExec.scala | 405 +++ .../sql/hive/HivemallGroupedDataset.scala | 636 +++++ .../apache/spark/sql/hive/HivemallOps.scala | 2249 +++++++++++++++++ .../apache/spark/sql/hive/HivemallUtils.scala | 146 ++ .../sql/hive/internal/HivemallOpsImpl.scala | 79 + .../sql/hive/source/XGBoostFileFormat.scala | 163 ++ .../streaming/HivemallStreamingOps.scala | 47 + .../src/test/resources/data/files/README.md | 22 + .../src/test/resources/data/files/complex.seq | 0 .../test/resources/data/files/episodes.avro | 0 .../src/test/resources/data/files/json.txt | 0 .../src/test/resources/data/files/kv1.txt | 0 .../src/test/resources/data/files/kv3.txt | 0 .../src/test/resources/log4j.properties | 24 + .../hivemall/mix/server/MixServerSuite.scala | 124 + .../tools/RegressionDatagenSuite.scala | 33 + .../feature/HivemallLabeledPointSuite.scala | 36 + .../benchmark/BenchmarkBaseAccessor.scala | 23 + .../apache/spark/sql/hive/HiveUdfSuite.scala | 161 ++ .../spark/sql/hive/HivemallOpsSuite.scala | 1393 ++++++++++ .../spark/sql/hive/ModelMixingSuite.scala | 285 +++ .../apache/spark/sql/hive/XGBoostSuite.scala | 151 ++ .../sql/hive/benchmark/MiscBenchmark.scala | 268 ++ .../hive/test/HivemallFeatureQueryTest.scala | 102 + .../spark/sql/test/VectorQueryTest.scala | 88 + .../HivemallOpsWithFeatureSuite.scala | 155 ++ .../org/apache/spark/test/TestUtils.scala | 65 + 41 files changed, 7926 insertions(+) create mode 100755 spark/spark-2.3/bin/mvn-zinc create mode 100644 spark/spark-2.3/extra-src/README.md create mode 100644 spark/spark-2.3/extra-src/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala create mode 100644 spark/spark-2.3/pom.xml create mode 100644 spark/spark-2.3/src/main/java/hivemall/xgboost/XGBoostOptions.scala create mode 100644 spark/spark-2.3/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister create mode 100644 spark/spark-2.3/src/main/resources/log4j.properties create mode 100644 spark/spark-2.3/src/main/scala/hivemall/tools/RegressionDatagen.scala create mode 100644 spark/spark-2.3/src/main/scala/org/apache/spark/sql/catalyst/expressions/EachTopK.scala create mode 100644 spark/spark-2.3/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/JoinTopK.scala create mode 100644 spark/spark-2.3/src/main/scala/org/apache/spark/sql/catalyst/utils/InternalRowPriorityQueue.scala create mode 100644 spark/spark-2.3/src/main/scala/org/apache/spark/sql/execution/UserProvidedPlanner.scala create mode 100644 spark/spark-2.3/src/main/scala/org/apache/spark/sql/execution/datasources/csv/csvExpressions.scala create mode 100644 spark/spark-2.3/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinTopKExec.scala create mode 100644 
spark/spark-2.3/src/main/scala/org/apache/spark/sql/hive/HivemallGroupedDataset.scala create mode 100644 spark/spark-2.3/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala create mode 100644 spark/spark-2.3/src/main/scala/org/apache/spark/sql/hive/HivemallUtils.scala create mode 100644 spark/spark-2.3/src/main/scala/org/apache/spark/sql/hive/internal/HivemallOpsImpl.scala create mode 100644 spark/spark-2.3/src/main/scala/org/apache/spark/sql/hive/source/XGBoostFileFormat.scala create mode 100644 spark/spark-2.3/src/main/scala/org/apache/spark/streaming/HivemallStreamingOps.scala create mode 100644 spark/spark-2.3/src/test/resources/data/files/README.md create mode 100644 spark/spark-2.3/src/test/resources/data/files/complex.seq create mode 100644 spark/spark-2.3/src/test/resources/data/files/episodes.avro create mode 100644 spark/spark-2.3/src/test/resources/data/files/json.txt create mode 100644 spark/spark-2.3/src/test/resources/data/files/kv1.txt create mode 100644 spark/spark-2.3/src/test/resources/data/files/kv3.txt create mode 100644 spark/spark-2.3/src/test/resources/log4j.properties create mode 100644 spark/spark-2.3/src/test/scala/hivemall/mix/server/MixServerSuite.scala create mode 100644 spark/spark-2.3/src/test/scala/hivemall/tools/RegressionDatagenSuite.scala create mode 100644 spark/spark-2.3/src/test/scala/org/apache/spark/ml/feature/HivemallLabeledPointSuite.scala create mode 100644 spark/spark-2.3/src/test/scala/org/apache/spark/sql/execution/benchmark/BenchmarkBaseAccessor.scala create mode 100644 spark/spark-2.3/src/test/scala/org/apache/spark/sql/hive/HiveUdfSuite.scala create mode 100644 spark/spark-2.3/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala create mode 100644 spark/spark-2.3/src/test/scala/org/apache/spark/sql/hive/ModelMixingSuite.scala create mode 100644 spark/spark-2.3/src/test/scala/org/apache/spark/sql/hive/XGBoostSuite.scala create mode 100644 spark/spark-2.3/src/test/scala/org/apache/spark/sql/hive/benchmark/MiscBenchmark.scala create mode 100644 spark/spark-2.3/src/test/scala/org/apache/spark/sql/hive/test/HivemallFeatureQueryTest.scala create mode 100644 spark/spark-2.3/src/test/scala/org/apache/spark/sql/test/VectorQueryTest.scala create mode 100644 spark/spark-2.3/src/test/scala/org/apache/spark/streaming/HivemallOpsWithFeatureSuite.scala create mode 100644 spark/spark-2.3/src/test/scala/org/apache/spark/test/TestUtils.scala diff --git a/spark/pom.xml b/spark/pom.xml index 8279df156..f0827ea33 100644 --- a/spark/pom.xml +++ b/spark/pom.xml @@ -35,6 +35,7 @@ spark-2.0 spark-2.1 spark-2.2 + spark-2.3 diff --git a/spark/spark-2.3/bin/mvn-zinc b/spark/spark-2.3/bin/mvn-zinc new file mode 100755 index 000000000..759b0a56d --- /dev/null +++ b/spark/spark-2.3/bin/mvn-zinc @@ -0,0 +1,99 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Copyed from commit 48682f6bf663e54cb63b7e95a4520d34b6fa890b in Apache Spark + +# Determine the current working directory +_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" +# Preserve the calling directory +_CALLING_DIR="$(pwd)" +# Options used during compilation +_COMPILE_JVM_OPTS="-Xmx2g -XX:MaxPermSize=512M -XX:ReservedCodeCacheSize=512m" + +# Installs any application tarball given a URL, the expected tarball name, +# and, optionally, a checkable binary path to determine if the binary has +# already been installed +## Arg1 - URL +## Arg2 - Tarball Name +## Arg3 - Checkable Binary +install_app() { + local remote_tarball="$1/$2" + local local_tarball="${_DIR}/$2" + local binary="${_DIR}/$3" + local curl_opts="--progress-bar -L" + local wget_opts="--progress=bar:force ${wget_opts}" + + if [ -z "$3" -o ! -f "$binary" ]; then + # check if we already have the tarball + # check if we have curl installed + # download application + [ ! -f "${local_tarball}" ] && [ $(command -v curl) ] && \ + echo "exec: curl ${curl_opts} ${remote_tarball}" 1>&2 && \ + curl ${curl_opts} "${remote_tarball}" > "${local_tarball}" + # if the file still doesn't exist, lets try `wget` and cross our fingers + [ ! -f "${local_tarball}" ] && [ $(command -v wget) ] && \ + echo "exec: wget ${wget_opts} ${remote_tarball}" 1>&2 && \ + wget ${wget_opts} -O "${local_tarball}" "${remote_tarball}" + # if both were unsuccessful, exit + [ ! -f "${local_tarball}" ] && \ + echo -n "ERROR: Cannot download $2 with cURL or wget; " && \ + echo "please install manually and try again." && \ + exit 2 + cd "${_DIR}" && tar -xzf "$2" + rm -rf "$local_tarball" + fi +} + +# Install zinc under the bin/ folder +install_zinc() { + local zinc_path="zinc-0.3.9/bin/zinc" + [ ! -f "${_DIR}/${zinc_path}" ] && ZINC_INSTALL_FLAG=1 + install_app \ + "http://downloads.typesafe.com/zinc/0.3.9" \ + "zinc-0.3.9.tgz" \ + "${zinc_path}" + ZINC_BIN="${_DIR}/${zinc_path}" +} + +# Setup healthy defaults for the Zinc port if none were provided from +# the environment +ZINC_PORT=${ZINC_PORT:-"3030"} + +# Install Zinc for the bin/ +install_zinc + +# Reset the current working directory +cd "${_CALLING_DIR}" + +# Now that zinc is ensured to be installed, check its status and, if its +# not running or just installed, start it +if [ ! -f "${ZINC_BIN}" ]; then + exit -1 +fi +if [ -n "${ZINC_INSTALL_FLAG}" -o -z "`"${ZINC_BIN}" -status -port ${ZINC_PORT}`" ]; then + export ZINC_OPTS=${ZINC_OPTS:-"$_COMPILE_JVM_OPTS"} + "${ZINC_BIN}" -shutdown -port ${ZINC_PORT} + "${ZINC_BIN}" -start -port ${ZINC_PORT} &>/dev/null +fi + +# Set any `mvn` options if not already present +export MAVEN_OPTS=${MAVEN_OPTS:-"$_COMPILE_JVM_OPTS"} + +# Last, call the `mvn` command as usual +mvn -DzincPort=${ZINC_PORT} "$@" diff --git a/spark/spark-2.3/extra-src/README.md b/spark/spark-2.3/extra-src/README.md new file mode 100644 index 000000000..0c622a2e1 --- /dev/null +++ b/spark/spark-2.3/extra-src/README.md @@ -0,0 +1,20 @@ + + +Copyed from the spark v2.3.0 release. 
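As an aside (not part of the patch itself), here is a minimal usage sketch of the `hivemall-spark2.3` module this patch introduces. It assumes a running `SparkSession` named `spark` with the built jar on its classpath and uses the `lr_datagen` DataFrame syntax defined in `HivemallOps` later in this patch; the option values are illustrative.

    import org.apache.spark.sql.functions.lit
    import org.apache.spark.sql.hive.HivemallOps._

    // Generate a small synthetic training dataset via the Hivemall lr_datagen UDTF
    val seed = spark.range(0, 2).toDF("data")
    val synthetic = seed.lr_datagen(lit("-n_examples 1000 -n_features 10 -n_dims 200"))
    synthetic.select("label", "features").show(5)
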
diff --git a/spark/spark-2.3/extra-src/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala b/spark/spark-2.3/extra-src/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala new file mode 100644 index 000000000..11afe1af3 --- /dev/null +++ b/spark/spark-2.3/extra-src/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala @@ -0,0 +1,279 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import java.io.{InputStream, OutputStream} +import java.rmi.server.UID + +import scala.collection.JavaConverters._ +import scala.language.implicitConversions +import scala.reflect.ClassTag + +import com.google.common.base.Objects +import org.apache.avro.Schema +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path +import org.apache.hadoop.hive.ql.exec.{UDF, Utilities} +import org.apache.hadoop.hive.ql.plan.{FileSinkDesc, TableDesc} +import org.apache.hadoop.hive.ql.udf.generic.GenericUDFMacro +import org.apache.hadoop.hive.serde2.ColumnProjectionUtils +import org.apache.hadoop.hive.serde2.avro.{AvroGenericRecordWritable, AvroSerdeUtils} +import org.apache.hadoop.hive.serde2.objectinspector.primitive.HiveDecimalObjectInspector +import org.apache.hadoop.io.Writable +import org.apache.hive.com.esotericsoftware.kryo.Kryo +import org.apache.hive.com.esotericsoftware.kryo.io.{Input, Output} + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.types.Decimal +import org.apache.spark.util.Utils + +private[hive] object HiveShim { + // Precision and scale to pass for unlimited decimals; these are the same as the precision and + // scale Hive 0.13 infers for BigDecimals from sources that don't specify them (e.g. 
UDFs) + val UNLIMITED_DECIMAL_PRECISION = 38 + val UNLIMITED_DECIMAL_SCALE = 18 + val HIVE_GENERIC_UDF_MACRO_CLS = "org.apache.hadoop.hive.ql.udf.generic.GenericUDFMacro" + + /* + * This function in hive-0.13 become private, but we have to do this to work around hive bug + */ + private def appendReadColumnNames(conf: Configuration, cols: Seq[String]) { + val old: String = conf.get(ColumnProjectionUtils.READ_COLUMN_NAMES_CONF_STR, "") + val result: StringBuilder = new StringBuilder(old) + var first: Boolean = old.isEmpty + + for (col <- cols) { + if (first) { + first = false + } else { + result.append(',') + } + result.append(col) + } + conf.set(ColumnProjectionUtils.READ_COLUMN_NAMES_CONF_STR, result.toString) + } + + /* + * Cannot use ColumnProjectionUtils.appendReadColumns directly, if ids is null + */ + def appendReadColumns(conf: Configuration, ids: Seq[Integer], names: Seq[String]) { + if (ids != null) { + ColumnProjectionUtils.appendReadColumns(conf, ids.asJava) + } + if (names != null) { + appendReadColumnNames(conf, names) + } + } + + /* + * Bug introduced in hive-0.13. AvroGenericRecordWritable has a member recordReaderID that + * is needed to initialize before serialization. + */ + def prepareWritable(w: Writable, serDeProps: Seq[(String, String)]): Writable = { + w match { + case w: AvroGenericRecordWritable => + w.setRecordReaderID(new UID()) + // In Hive 1.1, the record's schema may need to be initialized manually or a NPE will + // be thrown. + if (w.getFileSchema() == null) { + serDeProps + .find(_._1 == AvroSerdeUtils.AvroTableProperties.SCHEMA_LITERAL.getPropName()) + .foreach { kv => + w.setFileSchema(new Schema.Parser().parse(kv._2)) + } + } + case _ => + } + w + } + + def toCatalystDecimal(hdoi: HiveDecimalObjectInspector, data: Any): Decimal = { + if (hdoi.preferWritable()) { + Decimal(hdoi.getPrimitiveWritableObject(data).getHiveDecimal().bigDecimalValue, + hdoi.precision(), hdoi.scale()) + } else { + Decimal(hdoi.getPrimitiveJavaObject(data).bigDecimalValue(), hdoi.precision(), hdoi.scale()) + } + } + + /** + * This class provides the UDF creation and also the UDF instance serialization and + * de-serialization cross process boundary. 
+ * + * Detail discussion can be found at https://github.com/apache/spark/pull/3640 + * + * @param functionClassName UDF class name + * @param instance optional UDF instance which contains additional information (for macro) + */ + private[hive] case class HiveFunctionWrapper(var functionClassName: String, + private var instance: AnyRef = null) extends java.io.Externalizable { + + // for Serialization + def this() = this(null) + + override def hashCode(): Int = { + if (functionClassName == HIVE_GENERIC_UDF_MACRO_CLS) { + Objects.hashCode(functionClassName, instance.asInstanceOf[GenericUDFMacro].getBody()) + } else { + functionClassName.hashCode() + } + } + + override def equals(other: Any): Boolean = other match { + case a: HiveFunctionWrapper if functionClassName == a.functionClassName => + // In case of udf macro, check to make sure they point to the same underlying UDF + if (functionClassName == HIVE_GENERIC_UDF_MACRO_CLS) { + a.instance.asInstanceOf[GenericUDFMacro].getBody() == + instance.asInstanceOf[GenericUDFMacro].getBody() + } else { + true + } + case _ => false + } + + @transient + def deserializeObjectByKryo[T: ClassTag]( + kryo: Kryo, + in: InputStream, + clazz: Class[_]): T = { + val inp = new Input(in) + val t: T = kryo.readObject(inp, clazz).asInstanceOf[T] + inp.close() + t + } + + @transient + def serializeObjectByKryo( + kryo: Kryo, + plan: Object, + out: OutputStream) { + val output: Output = new Output(out) + kryo.writeObject(output, plan) + output.close() + } + + def deserializePlan[UDFType](is: java.io.InputStream, clazz: Class[_]): UDFType = { + deserializeObjectByKryo(Utilities.runtimeSerializationKryo.get(), is, clazz) + .asInstanceOf[UDFType] + } + + def serializePlan(function: AnyRef, out: java.io.OutputStream): Unit = { + serializeObjectByKryo(Utilities.runtimeSerializationKryo.get(), function, out) + } + + def writeExternal(out: java.io.ObjectOutput) { + // output the function name + out.writeUTF(functionClassName) + + // Write a flag if instance is null or not + out.writeBoolean(instance != null) + if (instance != null) { + // Some of the UDF are serializable, but some others are not + // Hive Utilities can handle both cases + val baos = new java.io.ByteArrayOutputStream() + serializePlan(instance, baos) + val functionInBytes = baos.toByteArray + + // output the function bytes + out.writeInt(functionInBytes.length) + out.write(functionInBytes, 0, functionInBytes.length) + } + } + + def readExternal(in: java.io.ObjectInput) { + // read the function name + functionClassName = in.readUTF() + + if (in.readBoolean()) { + // if the instance is not null + // read the function in bytes + val functionInBytesLength = in.readInt() + val functionInBytes = new Array[Byte](functionInBytesLength) + in.readFully(functionInBytes) + + // deserialize the function object via Hive Utilities + instance = deserializePlan[AnyRef](new java.io.ByteArrayInputStream(functionInBytes), + Utils.getContextOrSparkClassLoader.loadClass(functionClassName)) + } + } + + def createFunction[UDFType <: AnyRef](): UDFType = { + if (instance != null) { + instance.asInstanceOf[UDFType] + } else { + val func = Utils.getContextOrSparkClassLoader + .loadClass(functionClassName).newInstance.asInstanceOf[UDFType] + if (!func.isInstanceOf[UDF]) { + // We cache the function if it's no the Simple UDF, + // as we always have to create new instance for Simple UDF + instance = func + } + func + } + } + } + + /* + * Bug introduced in hive-0.13. FileSinkDesc is serializable, but its member path is not. 
+ * Fix it through wrapper. + */ + implicit def wrapperToFileSinkDesc(w: ShimFileSinkDesc): FileSinkDesc = { + val f = new FileSinkDesc(new Path(w.dir), w.tableInfo, w.compressed) + f.setCompressCodec(w.compressCodec) + f.setCompressType(w.compressType) + f.setTableInfo(w.tableInfo) + f.setDestTableId(w.destTableId) + f + } + + /* + * Bug introduced in hive-0.13. FileSinkDesc is serializable, but its member path is not. + * Fix it through wrapper. + */ + private[hive] class ShimFileSinkDesc( + var dir: String, + var tableInfo: TableDesc, + var compressed: Boolean) + extends Serializable with Logging { + var compressCodec: String = _ + var compressType: String = _ + var destTableId: Int = _ + + def setCompressed(compressed: Boolean) { + this.compressed = compressed + } + + def getDirName(): String = dir + + def setDestTableId(destTableId: Int) { + this.destTableId = destTableId + } + + def setTableInfo(tableInfo: TableDesc) { + this.tableInfo = tableInfo + } + + def setCompressCodec(intermediateCompressorCodec: String) { + compressCodec = intermediateCompressorCodec + } + + def setCompressType(intermediateCompressType: String) { + compressType = intermediateCompressType + } + } +} diff --git a/spark/spark-2.3/pom.xml b/spark/spark-2.3/pom.xml new file mode 100644 index 000000000..cfa64579d --- /dev/null +++ b/spark/spark-2.3/pom.xml @@ -0,0 +1,187 @@ + + + 4.0.0 + + + org.apache.hivemall + hivemall-spark + 0.5.1-incubating-SNAPSHOT + ../pom.xml + + + hivemall-spark2.3 + Hivemall on Spark 2.3 + jar + + + ${project.parent.parent.basedir} + 2.3.0 + 2.3 + 2.6.5 + -ea -Xms768m -Xmx2g -XX:MetaspaceSize=128m -XX:MaxMetaspaceSize=512m -XX:ReservedCodeCacheSize=512m + 1.8 + 1.8 + + + + + + org.apache.hivemall + hivemall-core + compile + + + org.apache.hivemall + hivemall-xgboost + compile + + + org.apache.hivemall + hivemall-spark-common + ${project.version} + compile + + + + + org.scala-lang + scala-library + provided + + + org.apache.spark + spark-core_${scala.binary.version} + ${spark.version} + provided + + + org.apache.spark + spark-sql_${scala.binary.version} + ${spark.version} + provided + + + org.apache.spark + spark-hive_${scala.binary.version} + ${spark.version} + provided + + + org.apache.spark + spark-streaming_${scala.binary.version} + ${spark.version} + provided + + + org.apache.spark + spark-mllib_${scala.binary.version} + ${spark.version} + provided + + + + + org.apache.hivemall + hivemall-mixserv + test + + + org.scalatest + scalatest_${scala.binary.version} + test + + + org.apache.spark + spark-core_${scala.binary.version} + ${spark.version} + test-jar + test + + + org.apache.spark + spark-streaming_${scala.binary.version} + ${spark.version} + test-jar + test + + + org.apache.spark + spark-sql_${scala.binary.version} + ${spark.version} + test-jar + test + + + org.apache.spark + spark-catalyst_${scala.binary.version} + ${spark.version} + test-jar + test + + + org.apache.spark + spark-hive_${scala.binary.version} + ${spark.version} + test-jar + test + + + + + + + + org.apache.maven.plugins + maven-shade-plugin + + + + org.apache.maven.plugins + maven-surefire-plugin + + true + + + + + org.scalatest + scalatest-maven-plugin + + + test + + test + + + + + + org.scalatest + scalatest-maven-plugin + + + ${env.JAVA8_HOME} + ${env.JAVA8_HOME}/bin:${env.PATH} + + + + + + diff --git a/spark/spark-2.3/src/main/java/hivemall/xgboost/XGBoostOptions.scala b/spark/spark-2.3/src/main/java/hivemall/xgboost/XGBoostOptions.scala new file mode 100644 index 000000000..3e0f274aa --- /dev/null +++ 
b/spark/spark-2.3/src/main/java/hivemall/xgboost/XGBoostOptions.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package hivemall.xgboost + +import scala.collection.mutable + +import org.apache.commons.cli.Options +import org.apache.spark.annotation.AlphaComponent + +/** + * :: AlphaComponent :: + * An utility class to generate a sequence of options used in XGBoost. + */ +@AlphaComponent +case class XGBoostOptions() { + private val params: mutable.Map[String, String] = mutable.Map.empty + private val options: Options = { + new XGBoostUDTF() { + def options(): Options = super.getOptions() + }.options() + } + + private def isValidKey(key: String): Boolean = { + // TODO: Is there another way to handle all the XGBoost options? + options.hasOption(key) || key == "num_class" + } + + def set(key: String, value: String): XGBoostOptions = { + require(isValidKey(key), s"non-existing key detected in XGBoost options: ${key}") + params.put(key, value) + this + } + + def help(): Unit = { + import scala.collection.JavaConversions._ + options.getOptions.map { case option => println(option) } + } + + override def toString(): String = { + params.map { case (key, value) => s"-$key $value" }.mkString(" ") + } +} diff --git a/spark/spark-2.3/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/spark/spark-2.3/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister new file mode 100644 index 000000000..b49e20a2a --- /dev/null +++ b/spark/spark-2.3/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -0,0 +1 @@ +org.apache.spark.sql.hive.source.XGBoostFileFormat diff --git a/spark/spark-2.3/src/main/resources/log4j.properties b/spark/spark-2.3/src/main/resources/log4j.properties new file mode 100644 index 000000000..ef4f6063b --- /dev/null +++ b/spark/spark-2.3/src/main/resources/log4j.properties @@ -0,0 +1,29 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +# Set everything to be logged to the console +log4j.rootCategory=INFO, console +log4j.appender.console=org.apache.log4j.ConsoleAppender +log4j.appender.console.target=System.err +log4j.appender.console.layout=org.apache.log4j.PatternLayout +log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n + +# Settings to quiet third party logs that are too verbose +log4j.logger.org.eclipse.jetty=INFO +log4j.logger.org.eclipse.jetty.util.component.AbstractLifeCycle=ERROR +log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO +log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO diff --git a/spark/spark-2.3/src/main/scala/hivemall/tools/RegressionDatagen.scala b/spark/spark-2.3/src/main/scala/hivemall/tools/RegressionDatagen.scala new file mode 100644 index 000000000..a2b7f600a --- /dev/null +++ b/spark/spark-2.3/src/main/scala/hivemall/tools/RegressionDatagen.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package hivemall.tools + +import org.apache.spark.sql.{DataFrame, Row, SQLContext} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.hive.HivemallOps._ +import org.apache.spark.sql.types._ + +object RegressionDatagen { + + /** + * Generate data for regression/classification. + * See [[hivemall.dataset.LogisticRegressionDataGeneratorUDTF]] + * for the details of arguments below. 
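+ *
+ * For illustration (this example is not part of the original source), a possible call with
+ * arbitrary argument values; `spark` is assumed to be an active SparkSession:
+ * {{{
+ *   // Returns a DataFrame with a double `label` column and a `features` column
+ *   val df = RegressionDatagen.exec(spark.sqlContext, min_examples = 10000, n_features = 100, cl = true)
+ * }}}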
+ */ + def exec(sc: SQLContext, + n_partitions: Int = 2, + min_examples: Int = 1000, + n_features: Int = 10, + n_dims: Int = 200, + seed: Int = 43, + dense: Boolean = false, + prob_one: Float = 0.6f, + sort: Boolean = false, + cl: Boolean = false): DataFrame = { + + require(n_partitions > 0, "Non-negative #n_partitions required.") + require(min_examples > 0, "Non-negative #min_examples required.") + require(n_features > 0, "Non-negative #n_features required.") + require(n_dims > 0, "Non-negative #n_dims required.") + + // Calculate #examples to generate in each partition + val n_examples = (min_examples + n_partitions - 1) / n_partitions + + val df = sc.createDataFrame( + sc.sparkContext.parallelize((0 until n_partitions).map(Row(_)), n_partitions), + StructType( + StructField("data", IntegerType, true) :: + Nil) + ) + import sc.implicits._ + df.lr_datagen( + lit(s"-n_examples $n_examples -n_features $n_features -n_dims $n_dims -prob_one $prob_one" + + (if (dense) " -dense" else "") + + (if (sort) " -sort" else "") + + (if (cl) " -cl" else "")) + ).select($"label".cast(DoubleType).as("label"), $"features") + } +} diff --git a/spark/spark-2.3/src/main/scala/org/apache/spark/sql/catalyst/expressions/EachTopK.scala b/spark/spark-2.3/src/main/scala/org/apache/spark/sql/catalyst/expressions/EachTopK.scala new file mode 100644 index 000000000..cac2a5dcd --- /dev/null +++ b/spark/spark-2.3/src/main/scala/org/apache/spark/sql/catalyst/expressions/EachTopK.scala @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.util.TypeUtils +import org.apache.spark.sql.catalyst.utils.InternalRowPriorityQueue +import org.apache.spark.sql.types._ + +trait TopKHelper { + + def k: Int + def scoreType: DataType + + @transient val ScoreTypes = TypeCollection( + ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType, DecimalType + ) + + protected case class ScoreWriter(writer: UnsafeRowWriter, ordinal: Int) { + + def write(v: Any): Unit = scoreType match { + case ByteType => writer.write(ordinal, v.asInstanceOf[Byte]) + case ShortType => writer.write(ordinal, v.asInstanceOf[Short]) + case IntegerType => writer.write(ordinal, v.asInstanceOf[Int]) + case LongType => writer.write(ordinal, v.asInstanceOf[Long]) + case FloatType => writer.write(ordinal, v.asInstanceOf[Float]) + case DoubleType => writer.write(ordinal, v.asInstanceOf[Double]) + case d: DecimalType => writer.write(ordinal, v.asInstanceOf[Decimal], d.precision, d.scale) + } + } + + protected lazy val scoreOrdering = { + val ordering = TypeUtils.getInterpretedOrdering(scoreType) + if (k > 0) ordering else ordering.reverse + } + + protected lazy val reverseScoreOrdering = scoreOrdering.reverse + + protected lazy val queue: InternalRowPriorityQueue = { + new InternalRowPriorityQueue(Math.abs(k), (x: Any, y: Any) => scoreOrdering.compare(x, y)) + } +} + +case class EachTopK( + k: Int, + scoreExpr: Expression, + groupExprs: Seq[Expression], + elementSchema: StructType, + children: Seq[Attribute]) + extends Generator with TopKHelper with CodegenFallback { + + override val scoreType: DataType = scoreExpr.dataType + + private lazy val groupingProjection: UnsafeProjection = UnsafeProjection.create(groupExprs) + private lazy val scoreProjection: UnsafeProjection = UnsafeProjection.create(scoreExpr :: Nil) + + // The grouping key of the current partition + private var currentGroupingKeys: UnsafeRow = _ + + override def checkInputDataTypes(): TypeCheckResult = { + if (!ScoreTypes.acceptsType(scoreExpr.dataType)) { + TypeCheckResult.TypeCheckFailure(s"$scoreExpr must have a comparable type") + } else { + TypeCheckResult.TypeCheckSuccess + } + } + + private def topKRowsForGroup(): Seq[InternalRow] = if (queue.size > 0) { + val outputRows = queue.iterator.toSeq.reverse + val (headScore, _) = outputRows.head + val rankNum = outputRows.scanLeft((1, headScore)) { case ((rank, prevScore), (score, _)) => + if (prevScore == score) (rank, score) else (rank + 1, score) + } + val topKRow = new UnsafeRow(1) + val bufferHolder = new BufferHolder(topKRow) + val unsafeRowWriter = new UnsafeRowWriter(bufferHolder, 1) + outputRows.zip(rankNum.map(_._1)).map { case ((_, row), index) => + // Writes to an UnsafeRow directly + bufferHolder.reset() + unsafeRowWriter.write(0, index) + topKRow.setTotalSize(bufferHolder.totalSize()) + new JoinedRow(topKRow, row) + } + } else { + Seq.empty + } + + override def eval(input: InternalRow): TraversableOnce[InternalRow] = { + val groupingKeys = groupingProjection(input) + val ret = if (currentGroupingKeys != groupingKeys) { + val topKRows = topKRowsForGroup() + currentGroupingKeys = groupingKeys.copy() + queue.clear() + topKRows + } else { + Iterator.empty + } + queue += Tuple2(scoreProjection(input).get(0, scoreType), input) + ret + } + + override def terminate(): 
TraversableOnce[InternalRow] = { + if (queue.size > 0) { + val topKRows = topKRowsForGroup() + queue.clear() + topKRows + } else { + Iterator.empty + } + } + + // TODO: Need to support codegen + // protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode +} diff --git a/spark/spark-2.3/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/JoinTopK.scala b/spark/spark-2.3/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/JoinTopK.scala new file mode 100644 index 000000000..556cdc3dd --- /dev/null +++ b/spark/spark-2.3/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/JoinTopK.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} +import org.apache.spark.sql.types.{BooleanType, IntegerType} + +case class JoinTopK( + k: Int, + left: LogicalPlan, + right: LogicalPlan, + joinType: JoinType, + condition: Option[Expression])( + val scoreExpr: NamedExpression, + private[sql] val rankAttr: Seq[Attribute] = AttributeReference("rank", IntegerType)() :: Nil) + extends BinaryNode with PredicateHelper { + + override def output: Seq[Attribute] = joinType match { + case Inner => rankAttr ++ Seq(scoreExpr.toAttribute) ++ left.output ++ right.output + } + + override def references: AttributeSet = { + AttributeSet((expressions ++ Seq(scoreExpr)).flatMap(_.references)) + } + + override protected def validConstraints: Set[Expression] = joinType match { + case Inner if condition.isDefined => + left.constraints.union(right.constraints) + .union(splitConjunctivePredicates(condition.get).toSet) + } + + override protected final def otherCopyArgs: Seq[AnyRef] = { + scoreExpr :: rankAttr :: Nil + } + + def duplicateResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty + + lazy val resolvedExceptNatural: Boolean = { + childrenResolved && + expressions.forall(_.resolved) && + duplicateResolved && + condition.forall(_.dataType == BooleanType) + } + + override lazy val resolved: Boolean = joinType match { + case Inner => resolvedExceptNatural + case tpe => throw new AnalysisException(s"Unsupported using join type $tpe") + } +} diff --git a/spark/spark-2.3/src/main/scala/org/apache/spark/sql/catalyst/utils/InternalRowPriorityQueue.scala b/spark/spark-2.3/src/main/scala/org/apache/spark/sql/catalyst/utils/InternalRowPriorityQueue.scala new file mode 100644 index 000000000..12c20fbbb --- /dev/null +++ b/spark/spark-2.3/src/main/scala/org/apache/spark/sql/catalyst/utils/InternalRowPriorityQueue.scala @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * 
or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.catalyst.utils + +import java.io.Serializable +import java.util.{PriorityQueue => JPriorityQueue} + +import scala.collection.JavaConverters._ +import scala.collection.generic.Growable + +import org.apache.spark.sql.catalyst.InternalRow + +private[sql] class InternalRowPriorityQueue( + maxSize: Int, + compareFunc: (Any, Any) => Int + ) extends Iterable[(Any, InternalRow)] with Growable[(Any, InternalRow)] with Serializable { + + private[this] val ordering = new Ordering[(Any, InternalRow)] { + override def compare(x: (Any, InternalRow), y: (Any, InternalRow)): Int = + compareFunc(x._1, y._1) + } + + private val underlying = new JPriorityQueue[(Any, InternalRow)](maxSize, ordering) + + override def iterator: Iterator[(Any, InternalRow)] = underlying.iterator.asScala + + override def size: Int = underlying.size + + override def ++=(xs: TraversableOnce[(Any, InternalRow)]): this.type = { + xs.foreach { this += _ } + this + } + + override def +=(elem: (Any, InternalRow)): this.type = { + if (size < maxSize) { + underlying.offer((elem._1, elem._2.copy())) + } else { + maybeReplaceLowest(elem) + } + this + } + + override def +=(elem1: (Any, InternalRow), elem2: (Any, InternalRow), elems: (Any, InternalRow)*) + : this.type = { + this += elem1 += elem2 ++= elems + } + + override def clear() { underlying.clear() } + + private def maybeReplaceLowest(a: (Any, InternalRow)): Boolean = { + val head = underlying.peek() + if (head != null && ordering.gt(a, head)) { + underlying.poll() + underlying.offer((a._1, a._2.copy())) + } else { + false + } + } +} diff --git a/spark/spark-2.3/src/main/scala/org/apache/spark/sql/execution/UserProvidedPlanner.scala b/spark/spark-2.3/src/main/scala/org/apache/spark/sql/execution/UserProvidedPlanner.scala new file mode 100644 index 000000000..09d60a645 --- /dev/null +++ b/spark/spark-2.3/src/main/scala/org/apache/spark/sql/execution/UserProvidedPlanner.scala @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.spark.sql.execution + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.Strategy +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.JoinType +import org.apache.spark.sql.catalyst.plans.logical.{JoinTopK, LogicalPlan} +import org.apache.spark.sql.internal.SQLConf + +private object ExtractJoinTopKKeys extends Logging with PredicateHelper { + /** (k, scoreExpr, joinType, leftKeys, rightKeys, condition, leftChild, rightChild) */ + type ReturnType = + (Int, NamedExpression, Seq[Attribute], JoinType, Seq[Expression], Seq[Expression], + Option[Expression], LogicalPlan, LogicalPlan) + + def unapply(plan: LogicalPlan): Option[ReturnType] = plan match { + case join @ JoinTopK(k, left, right, joinType, condition) => + logDebug(s"Considering join on: $condition") + val predicates = condition.map(splitConjunctivePredicates).getOrElse(Nil) + val joinKeys = predicates.flatMap { + case EqualTo(l, r) if canEvaluate(l, left) && canEvaluate(r, right) => Some((l, r)) + case EqualTo(l, r) if canEvaluate(l, right) && canEvaluate(r, left) => Some((r, l)) + // Replace null with default value for joining key, then those rows with null in it could + // be joined together + case EqualNullSafe(l, r) if canEvaluate(l, left) && canEvaluate(r, right) => + Some((Coalesce(Seq(l, Literal.default(l.dataType))), + Coalesce(Seq(r, Literal.default(r.dataType))))) + case EqualNullSafe(l, r) if canEvaluate(l, right) && canEvaluate(r, left) => + Some((Coalesce(Seq(r, Literal.default(r.dataType))), + Coalesce(Seq(l, Literal.default(l.dataType))))) + case other => None + } + val otherPredicates = predicates.filterNot { + case EqualTo(l, r) => + canEvaluate(l, left) && canEvaluate(r, right) || + canEvaluate(l, right) && canEvaluate(r, left) + case other => false + } + + if (joinKeys.nonEmpty) { + val (leftKeys, rightKeys) = joinKeys.unzip + logDebug(s"leftKeys:$leftKeys | rightKeys:$rightKeys") + Some((k, join.scoreExpr, join.rankAttr, joinType, leftKeys, rightKeys, + otherPredicates.reduceOption(And), left, right)) + } else { + None + } + + case p => + None + } +} + +private[sql] class UserProvidedPlanner(val conf: SQLConf) extends Strategy { + + override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case ExtractJoinTopKKeys( + k, scoreExpr, rankAttr, _, leftKeys, rightKeys, condition, left, right) => + Seq(joins.ShuffledHashJoinTopKExec( + k, leftKeys, rightKeys, condition, planLater(left), planLater(right))(scoreExpr, rankAttr)) + case _ => + Nil + } +} diff --git a/spark/spark-2.3/src/main/scala/org/apache/spark/sql/execution/datasources/csv/csvExpressions.scala b/spark/spark-2.3/src/main/scala/org/apache/spark/sql/execution/datasources/csv/csvExpressions.scala new file mode 100644 index 000000000..1f56c906e --- /dev/null +++ b/spark/spark-2.3/src/main/scala/org/apache/spark/sql/execution/datasources/csv/csvExpressions.scala @@ -0,0 +1,169 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.execution.datasources.csv + +import com.univocity.parsers.csv.CsvWriter + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, TimeZoneAwareExpression, UnaryExpression} +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +/** + * Converts a csv input string to a [[StructType]] with the specified schema. + * + * TODO: Move this class into org.apache.spark.sql.catalyst.expressions in Spark-v2.2+ + */ +case class CsvToStruct( + schema: StructType, + options: Map[String, String], + child: Expression, + timeZoneId: Option[String] = None) + extends UnaryExpression with TimeZoneAwareExpression with CodegenFallback with ExpectsInputTypes { + + def this(schema: StructType, options: Map[String, String], child: Expression) = + this(schema, options, child, None) + + override def nullable: Boolean = true + + @transient private lazy val csvOptions = new CSVOptions(options, timeZoneId.get) + @transient private lazy val csvParser = new UnivocityParser(schema, schema, csvOptions) + + private def parse(input: String): InternalRow = csvParser.parse(input) + + override def dataType: DataType = schema + + override def nullSafeEval(csv: Any): Any = { + try parse(csv.toString) catch { case _: RuntimeException => null } + } + + override def inputTypes: Seq[AbstractDataType] = StringType :: Nil + + override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = + copy(timeZoneId = Option(timeZoneId)) +} + +private class CsvGenerator(schema: StructType, options: CSVOptions) { + + // A `ValueConverter` is responsible for converting a value of an `InternalRow` to `String`. + // When the value is null, this converter should not be called. + private type ValueConverter = (InternalRow, Int) => String + + // `ValueConverter`s for all values in the fields of the schema + private val valueConverters: Array[ValueConverter] = + schema.map(_.dataType).map(makeConverter).toArray + + private def makeConverter(dataType: DataType): ValueConverter = dataType match { + case DateType => + (row: InternalRow, ordinal: Int) => + options.dateFormat.format(DateTimeUtils.toJavaDate(row.getInt(ordinal))) + + case TimestampType => + (row: InternalRow, ordinal: Int) => + options.timestampFormat.format(DateTimeUtils.toJavaTimestamp(row.getLong(ordinal))) + + case udt: UserDefinedType[_] => makeConverter(udt.sqlType) + + case dt: DataType => + (row: InternalRow, ordinal: Int) => + row.get(ordinal, dt).toString + } + + def convertRow(row: InternalRow): Seq[String] = { + var i = 0 + val values = new Array[String](row.numFields) + while (i < row.numFields) { + if (!row.isNullAt(i)) { + values(i) = valueConverters(i).apply(row, i) + } else { + values(i) = options.nullValue + } + i += 1 + } + values + } +} + +/** + * Converts a [[StructType]] to a csv output string. 
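+ *
+ * For illustration only (not part of the original patch), a rough construction sketch where
+ * `structExpr` stands for any resolved struct-typed child expression:
+ * {{{
+ *   val csvExpr = StructToCsv(options = Map.empty, child = structExpr, timeZoneId = Some("UTC"))
+ *   // csvExpr.dataType is StringType; evaluating it yields one CSV-formatted line per input row
+ * }}}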
+ */ +case class StructToCsv( + options: Map[String, String], + child: Expression, + timeZoneId: Option[String] = None) + extends UnaryExpression with TimeZoneAwareExpression with CodegenFallback with ExpectsInputTypes { + override def nullable: Boolean = true + + @transient + private lazy val params = new CSVOptions(options, timeZoneId.get) + + @transient + private lazy val dataSchema = child.dataType.asInstanceOf[StructType] + + @transient + private lazy val writer = new CsvGenerator(dataSchema, params) + + override def dataType: DataType = StringType + + private def verifySchema(schema: StructType): Unit = { + def verifyType(dataType: DataType): Unit = dataType match { + case ByteType | ShortType | IntegerType | LongType | FloatType | + DoubleType | BooleanType | _: DecimalType | TimestampType | + DateType | StringType => + + case udt: UserDefinedType[_] => verifyType(udt.sqlType) + + case _ => + throw new UnsupportedOperationException( + s"CSV data source does not support ${dataType.simpleString} data type.") + } + + schema.foreach(field => verifyType(field.dataType)) + } + + override def checkInputDataTypes(): TypeCheckResult = { + if (StructType.acceptsType(child.dataType)) { + try { + verifySchema(child.dataType.asInstanceOf[StructType]) + TypeCheckResult.TypeCheckSuccess + } catch { + case e: UnsupportedOperationException => + TypeCheckResult.TypeCheckFailure(e.getMessage) + } + } else { + TypeCheckResult.TypeCheckFailure( + s"$prettyName requires that the expression is a struct expression.") + } + } + + override def nullSafeEval(row: Any): Any = { + val rowStr = writer.convertRow(row.asInstanceOf[InternalRow]) + .mkString(params.delimiter.toString) + UTF8String.fromString(rowStr) + } + + override def inputTypes: Seq[AbstractDataType] = StructType :: Nil + + override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = + copy(timeZoneId = Option(timeZoneId)) +} diff --git a/spark/spark-2.3/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinTopKExec.scala b/spark/spark-2.3/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinTopKExec.scala new file mode 100644 index 000000000..f628b78be --- /dev/null +++ b/spark/spark-2.3/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinTopKExec.scala @@ -0,0 +1,405 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.spark.sql.execution.joins + +import org.apache.spark.TaskContext +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.catalyst.utils.InternalRowPriorityQueue +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.metric._ +import org.apache.spark.sql.types._ + +abstract class PriorityQueueShim { + + def insert(score: Any, row: InternalRow): Unit + def get(): Iterator[InternalRow] + def clear(): Unit +} + +case class ShuffledHashJoinTopKExec( + k: Int, + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + condition: Option[Expression], + left: SparkPlan, + right: SparkPlan)( + scoreExpr: NamedExpression, + rankAttr: Seq[Attribute]) + extends BinaryExecNode with TopKHelper with HashJoin with CodegenSupport { + + override lazy val metrics = Map( + "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) + + override val scoreType: DataType = scoreExpr.dataType + override val joinType: JoinType = Inner + override val buildSide: BuildSide = BuildRight // Only support `BuildRight` + + private lazy val scoreProjection: UnsafeProjection = + UnsafeProjection.create(scoreExpr :: Nil, left.output ++ right.output) + + private lazy val boundCondition = if (condition.isDefined) { + (r: InternalRow) => newPredicate(condition.get, streamedPlan.output ++ buildPlan.output).eval(r) + } else { + (r: InternalRow) => true + } + + private lazy val topKAttr = rankAttr :+ scoreExpr.toAttribute + + private lazy val _priorityQueue = new PriorityQueueShim { + + private val q: InternalRowPriorityQueue = queue + private val joinedRow = new JoinedRow + + override def insert(score: Any, row: InternalRow): Unit = { + q += Tuple2(score, row) + } + + override def get(): Iterator[InternalRow] = { + val outputRows = queue.iterator.toSeq.reverse + val (headScore, _) = outputRows.head + val rankNum = outputRows.scanLeft((1, headScore)) { case ((rank, prevScore), (score, _)) => + if (prevScore == score) (rank, score) else (rank + 1, score) + } + val topKRow = new UnsafeRow(2) + val bufferHolder = new BufferHolder(topKRow) + val unsafeRowWriter = new UnsafeRowWriter(bufferHolder, 2) + val scoreWriter = ScoreWriter(unsafeRowWriter, 1) + outputRows.zip(rankNum.map(_._1)).map { case ((score, row), index) => + // Writes to an UnsafeRow directly + bufferHolder.reset() + unsafeRowWriter.write(0, index) + scoreWriter.write(score) + topKRow.setTotalSize(bufferHolder.totalSize()) + joinedRow.apply(topKRow, row) + }.iterator + } + + override def clear(): Unit = q.clear() + } + + override def output: Seq[Attribute] = joinType match { + case Inner => topKAttr ++ left.output ++ right.output + } + + override protected final def otherCopyArgs: Seq[AnyRef] = { + scoreExpr :: rankAttr :: Nil + } + + override def requiredChildDistribution: Seq[Distribution] = + ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil + + def buildHashedRelation(iter: Iterator[InternalRow]): HashedRelation = { + val context = TaskContext.get() + val relation = HashedRelation(iter, buildKeys, taskMemoryManager = context.taskMemoryManager()) + context.addTaskCompletionListener(_ => relation.close()) + relation + } + + override protected def createResultProjection(): (InternalRow) => InternalRow = 
joinType match { + case Inner => + // Always put the stream side on left to simplify implementation + // both of left and right side could be null + UnsafeProjection.create( + output, (topKAttr ++ streamedPlan.output ++ buildPlan.output).map(_.withNullability(true))) + } + + protected def InnerJoin( + streamedIter: Iterator[InternalRow], + hashedRelation: HashedRelation, + numOutputRows: SQLMetric): Iterator[InternalRow] = { + val joinRow = new JoinedRow + val joinKeysProj = streamSideKeyGenerator() + val joinedIter = streamedIter.flatMap { srow => + joinRow.withLeft(srow) + val joinKeys = joinKeysProj(srow) // `joinKeys` is also a grouping key + val matches = hashedRelation.get(joinKeys) + if (matches != null) { + matches.map(joinRow.withRight).filter(boundCondition).foreach { resultRow => + _priorityQueue.insert(scoreProjection(resultRow).get(0, scoreType), resultRow) + } + val iter = _priorityQueue.get() + _priorityQueue.clear() + iter + } else { + Seq.empty + } + } + val resultProj = createResultProjection() + (joinedIter ++ queue.iterator.toSeq.sortBy(_._1)(reverseScoreOrdering) + .map(_._2)).map { r => + resultProj(r) + } + } + + override protected def doExecute(): RDD[InternalRow] = { + streamedPlan.execute().zipPartitions(buildPlan.execute()) { (streamIter, buildIter) => + val hashed = buildHashedRelation(buildIter) + InnerJoin(streamIter, hashed, null) + } + } + + override def inputRDDs(): Seq[RDD[InternalRow]] = { + left.execute() :: right.execute() :: Nil + } + + // Accessor for generated code + def priorityQueue(): PriorityQueueShim = _priorityQueue + + /** + * Add a state of HashedRelation and return the variable name for it. + */ + private def prepareHashedRelation(ctx: CodegenContext): String = { + // create a name for HashedRelation + val joinExec = ctx.addReferenceObj("joinExec", this) + val relationTerm = ctx.freshName("relation") + val clsName = HashedRelation.getClass.getName.replace("$", "") + ctx.addMutableState(clsName, relationTerm, + v => s""" + | $v = ($clsName) $joinExec.buildHashedRelation(inputs[1]); + | incPeakExecutionMemory($v.estimatedSize()); + """.stripMargin) + relationTerm + } + + /** + * Creates variables for left part of result row. + * + * In order to defer the access after condition and also only access once in the loop, + * the variables should be declared separately from accessing the columns, we can't use the + * codegen of BoundReference here. + */ + private def createLeftVars(ctx: CodegenContext, leftRow: String): Seq[ExprCode] = { + ctx.INPUT_ROW = leftRow + left.output.zipWithIndex.map { case (a, i) => + val value = ctx.freshName("value") + val valueCode = ctx.getValue(leftRow, a.dataType, i.toString) + // declare it as class member, so we can access the column before or in the loop. + ctx.addMutableState(ctx.javaType(a.dataType), value, _ => "") + if (a.nullable) { + val isNull = ctx.freshName("isNull") + ctx.addMutableState("boolean", isNull, _ => "") + val code = + s""" + |$isNull = $leftRow.isNullAt($i); + |$value = $isNull ? ${ctx.defaultValue(a.dataType)} : ($valueCode); + """.stripMargin + ExprCode(code, isNull, value) + } else { + ExprCode(s"$value = $valueCode;", "false", value) + } + } + } + + /** + * Creates the variables for right part of result row, using BoundReference, since the right + * part are accessed inside the loop. 
+ */ + private def createRightVar(ctx: CodegenContext, rightRow: String): Seq[ExprCode] = { + ctx.INPUT_ROW = rightRow + right.output.zipWithIndex.map { case (a, i) => + BoundReference(i, a.dataType, a.nullable).genCode(ctx) + } + } + + /** + * Returns the code for generating join key for stream side, and expression of whether the key + * has any null in it or not. + */ + private def genStreamSideJoinKey(ctx: CodegenContext, leftRow: String): (ExprCode, String) = { + ctx.INPUT_ROW = leftRow + if (streamedKeys.length == 1 && streamedKeys.head.dataType == LongType) { + // generate the join key as Long + val ev = streamedKeys.head.genCode(ctx) + (ev, ev.isNull) + } else { + // generate the join key as UnsafeRow + val ev = GenerateUnsafeProjection.createCode(ctx, streamedKeys) + (ev, s"${ev.value}.anyNull()") + } + } + + private def createScoreVar(ctx: CodegenContext, row: String): ExprCode = { + ctx.INPUT_ROW = row + BindReferences.bindReference(scoreExpr, left.output ++ right.output).genCode(ctx) + } + + private def createResultVars(ctx: CodegenContext, resultRow: String): Seq[ExprCode] = { + ctx.INPUT_ROW = resultRow + output.zipWithIndex.map { case (a, i) => + val value = ctx.freshName("value") + val valueCode = ctx.getValue(resultRow, a.dataType, i.toString) + // declare it as class member, so we can access the column before or in the loop. + ctx.addMutableState(ctx.javaType(a.dataType), value, _ => "") + if (a.nullable) { + val isNull = ctx.freshName("isNull") + ctx.addMutableState("boolean", isNull, _ => "") + val code = + s""" + |$isNull = $resultRow.isNullAt($i); + |$value = $isNull ? ${ctx.defaultValue(a.dataType)} : ($valueCode); + """.stripMargin + ExprCode(code, isNull, value) + } else { + ExprCode(s"$value = $valueCode;", "false", value) + } + } + } + + /** + * Splits variables based on whether it's used by condition or not, returns the code to create + * these variables before the condition and after the condition. + * + * Only a few columns are used by condition, then we can skip the accessing of those columns + * that are not used by condition also filtered out by condition. 
+ */ + private def splitVarsByCondition( + attributes: Seq[Attribute], + variables: Seq[ExprCode]): (String, String) = { + if (condition.isDefined) { + val condRefs = condition.get.references + val (used, notUsed) = attributes.zip(variables).partition{ case (a, ev) => + condRefs.contains(a) + } + val beforeCond = evaluateVariables(used.map(_._2)) + val afterCond = evaluateVariables(notUsed.map(_._2)) + (beforeCond, afterCond) + } else { + (evaluateVariables(variables), "") + } + } + + override def needCopyResult: Boolean = true + + override def doProduce(ctx: CodegenContext): String = { + val topKJoin = ctx.addReferenceObj("topKJoin", this) + + // Prepare a priority queue for top-K computing + val pQueue = ctx.freshName("queue") + ctx.addMutableState(classOf[PriorityQueueShim].getName, pQueue, + v => s"$v= $topKJoin.priorityQueue();") + + // Prepare variables for a left side + val leftIter = ctx.freshName("leftIter") + ctx.addMutableState("scala.collection.Iterator", leftIter, v => s"$v = inputs[0];") + val leftRow = ctx.freshName("leftRow") + ctx.addMutableState("InternalRow", leftRow, v => "") + val leftVars = createLeftVars(ctx, leftRow) + + // Prepare variables for a right side + val rightRow = ctx.freshName("rightRow") + val rightVars = createRightVar(ctx, rightRow) + + // Build a hashed relation from a right side + val buildRelation = prepareHashedRelation(ctx) + + // Project join keys from a left side + val (keyEv, anyNull) = genStreamSideJoinKey(ctx, leftRow) + + // Prepare variables for joined rows + val joinedRow = ctx.freshName("joinedRow") + val joinedRowCls = classOf[JoinedRow].getName + ctx.addMutableState(joinedRowCls, joinedRow, v => s"$v = new $joinedRowCls();") + + // Project score values from joined rows + val scoreVar = createScoreVar(ctx, joinedRow) + + // Prepare variables for output rows + val resultRow = ctx.freshName("resultRow") + val resultVars = createResultVars(ctx, resultRow) + + val (beforeLoop, condCheck) = if (condition.isDefined) { + // Split the code of creating variables based on whether it's used by condition or not. + val loaded = ctx.freshName("loaded") + val (leftBefore, leftAfter) = splitVarsByCondition(left.output, leftVars) + val (rightBefore, rightAfter) = splitVarsByCondition(right.output, rightVars) + // Generate code for condition + ctx.currentVars = leftVars ++ rightVars + val cond = BindReferences.bindReference(condition.get, output).genCode(ctx) + // evaluate the columns those used by condition before loop + val before = s""" + |boolean $loaded = false; + |$leftBefore + """.stripMargin + + val checking = s""" + |$rightBefore + |${cond.code} + |if (${cond.isNull} || !${cond.value}) continue; + |if (!$loaded) { + | $loaded = true; + | $leftAfter + |} + |$rightAfter + """.stripMargin + (before, checking) + } else { + (evaluateVariables(leftVars), "") + } + + val numOutput = metricTerm(ctx, "numOutputRows") + + val matches = ctx.freshName("matches") + val topKRows = ctx.freshName("topKRows") + val iteratorCls = classOf[Iterator[UnsafeRow]].getName + + s""" + |$leftRow = null; + |while ($leftIter.hasNext()) { + | $leftRow = (InternalRow) $leftIter.next(); + | + | // Generate join key for stream side + | ${keyEv.code} + | + | // Find matches from HashedRelation + | $iteratorCls $matches = $anyNull? 
null : ($iteratorCls)$buildRelation.get(${keyEv.value}); + | if ($matches == null) continue; + | + | // Join top-K right rows + | while ($matches.hasNext()) { + | ${beforeLoop.trim} + | InternalRow $rightRow = (InternalRow) $matches.next(); + | ${condCheck.trim} + | InternalRow row = $joinedRow.apply($leftRow, $rightRow); + | // Compute a score for the `row` + | ${scoreVar.code} + | $pQueue.insert(${scoreVar.value}, row); + | } + | + | // Get top-K rows + | $iteratorCls $topKRows = $pQueue.get(); + | $pQueue.clear(); + | + | // Output top-K rows + | while ($topKRows.hasNext()) { + | InternalRow $resultRow = (InternalRow) $topKRows.next(); + | $numOutput.add(1); + | ${consume(ctx, resultVars)} + | } + | + | if (shouldStop()) return; + |} + """.stripMargin + } +} diff --git a/spark/spark-2.3/src/main/scala/org/apache/spark/sql/hive/HivemallGroupedDataset.scala b/spark/spark-2.3/src/main/scala/org/apache/spark/sql/hive/HivemallGroupedDataset.scala new file mode 100644 index 000000000..2982d9c9b --- /dev/null +++ b/spark/spark-2.3/src/main/scala/org/apache/spark/sql/hive/HivemallGroupedDataset.scala @@ -0,0 +1,636 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.hive + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.Dataset +import org.apache.spark.sql.RelationalGroupedDataset +import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical.Aggregate +import org.apache.spark.sql.catalyst.plans.logical.Pivot +import org.apache.spark.sql.hive.HiveShim.HiveFunctionWrapper +import org.apache.spark.sql.types._ + +/** + * Groups the [[DataFrame]] using the specified columns, so we can run aggregation on them. 
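+ *
+ * For example, with the implicit conversion in the [[HivemallGroupedDataset]] companion
+ * object in scope, the Hivemall UDAFs below can be called directly on a grouped
+ * [[DataFrame]]. A minimal sketch (the column names are illustrative; `score` must be a
+ * double column and `label` a string column, as enforced by `checkType`):
+ * {{{
+ *   import org.apache.spark.sql.hive.HivemallGroupedDataset._
+ *
+ *   // Keeps, for each user, the label of the row with the maximum score
+ *   // (hivemall.ensemble.MaxRowUDAF)
+ *   val best = df.groupBy("user").maxrow("score", "label")
+ * }}}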
+ * + * @groupname classifier + * @groupname ensemble + * @groupname evaluation + * @groupname topicmodel + * @groupname ftvec.selection + * @groupname ftvec.text + * @groupname ftvec.trans + * @groupname tools.array + * @groupname tools.bits + * @groupname tools.list + * @groupname tools.map + * @groupname tools.matrix + * @groupname tools.math + * + * A list of unsupported functions is as follows: + * * ftvec.conv + * - conv2dense + * - build_bins + */ +final class HivemallGroupedDataset(groupBy: RelationalGroupedDataset) { + + /** + * @see hivemall.classifier.KPAPredictUDAF + * @group classifier + */ + def kpa_predict(xh: String, xk: String, w0: String, w1: String, w2: String, w3: String) + : DataFrame = { + checkType(xh, DoubleType) + checkType(xk, DoubleType) + checkType(w0, FloatType) + checkType(w1, FloatType) + checkType(w2, FloatType) + checkType(w3, FloatType) + val udaf = HiveUDAFFunction( + "kpa_predict", + new HiveFunctionWrapper("hivemall.classifier.KPAPredictUDAF"), + Seq(xh, xk, w0, w1, w2, w3).map(df(_).expr), + isUDAFBridgeRequired = false) + .toAggregateExpression() + toDF(Alias(udaf, udaf.prettyName)() :: Nil) + } + + /** + * @see hivemall.ensemble.bagging.VotedAvgUDAF + * @group ensemble + */ + def voted_avg(weight: String): DataFrame = { + checkType(weight, DoubleType) + val udaf = HiveUDAFFunction( + "voted_avg", + new HiveFunctionWrapper("hivemall.ensemble.bagging.WeightVotedAvgUDAF"), + Seq(weight).map(df(_).expr), + isUDAFBridgeRequired = true) + .toAggregateExpression() + toDF(Alias(udaf, udaf.prettyName)() :: Nil) + } + + /** + * @see hivemall.ensemble.bagging.WeightVotedAvgUDAF + * @group ensemble + */ + def weight_voted_avg(weight: String): DataFrame = { + checkType(weight, DoubleType) + val udaf = HiveUDAFFunction( + "weight_voted_avg", + new HiveFunctionWrapper("hivemall.ensemble.bagging.WeightVotedAvgUDAF"), + Seq(weight).map(df(_).expr), + isUDAFBridgeRequired = true) + .toAggregateExpression() + toDF(Alias(udaf, udaf.prettyName)() :: Nil) + } + + /** + * @see hivemall.ensemble.ArgminKLDistanceUDAF + * @group ensemble + */ + def argmin_kld(weight: String, conv: String): DataFrame = { + checkType(weight, FloatType) + checkType(conv, FloatType) + val udaf = HiveUDAFFunction( + "argmin_kld", + new HiveFunctionWrapper("hivemall.ensemble.ArgminKLDistanceUDAF"), + Seq(weight, conv).map(df(_).expr), + isUDAFBridgeRequired = true) + .toAggregateExpression() + toDF(Alias(udaf, udaf.prettyName)() :: Nil) + } + + /** + * @see hivemall.ensemble.MaxValueLabelUDAF" + * @group ensemble + */ + def max_label(score: String, label: String): DataFrame = { + // checkType(score, DoubleType) + checkType(label, StringType) + val udaf = HiveUDAFFunction( + "max_label", + new HiveFunctionWrapper("hivemall.ensemble.MaxValueLabelUDAF"), + Seq(score, label).map(df(_).expr), + isUDAFBridgeRequired = true) + .toAggregateExpression() + toDF(Alias(udaf, udaf.prettyName)() :: Nil) + } + + /** + * @see hivemall.ensemble.MaxRowUDAF + * @group ensemble + */ + def maxrow(score: String, label: String): DataFrame = { + checkType(score, DoubleType) + checkType(label, StringType) + val udaf = HiveUDAFFunction( + "maxrow", + new HiveFunctionWrapper("hivemall.ensemble.MaxRowUDAF"), + Seq(score, label).map(df(_).expr), + isUDAFBridgeRequired = false) + .toAggregateExpression() + toDF(Alias(udaf, udaf.prettyName)() :: Nil) + } + + /** + * @see hivemall.smile.tools.RandomForestEnsembleUDAF + * @group ensemble + */ + @scala.annotation.varargs + def rf_ensemble(yhat: String, others: String*): DataFrame = 
{ + checkType(yhat, IntegerType) + val udaf = HiveUDAFFunction( + "rf_ensemble", + new HiveFunctionWrapper("hivemall.smile.tools.RandomForestEnsembleUDAF"), + (yhat +: others).map(df(_).expr), + isUDAFBridgeRequired = false) + .toAggregateExpression() + toDF(Alias(udaf, udaf.prettyName)() :: Nil) + } + + /** + * @see hivemall.evaluation.MeanAbsoluteErrorUDAF + * @group evaluation + */ + def mae(predict: String, target: String): DataFrame = { + checkType(predict, DoubleType) + checkType(target, DoubleType) + val udaf = HiveUDAFFunction( + "mae", + new HiveFunctionWrapper("hivemall.evaluation.MeanAbsoluteErrorUDAF"), + Seq(predict, target).map(df(_).expr), + isUDAFBridgeRequired = true) + .toAggregateExpression() + toDF(Alias(udaf, udaf.prettyName)() :: Nil) + } + + /** + * @see hivemall.evaluation.MeanSquareErrorUDAF + * @group evaluation + */ + def mse(predict: String, target: String): DataFrame = { + checkType(predict, DoubleType) + checkType(target, DoubleType) + val udaf = HiveUDAFFunction( + "mse", + new HiveFunctionWrapper("hivemall.evaluation.MeanSquaredErrorUDAF"), + Seq(predict, target).map(df(_).expr), + isUDAFBridgeRequired = true) + .toAggregateExpression() + toDF(Alias(udaf, udaf.prettyName)() :: Nil) + } + + /** + * @see hivemall.evaluation.RootMeanSquareErrorUDAF + * @group evaluation + */ + def rmse(predict: String, target: String): DataFrame = { + checkType(predict, DoubleType) + checkType(target, DoubleType) + val udaf = HiveUDAFFunction( + "rmse", + new HiveFunctionWrapper("hivemall.evaluation.RootMeanSquaredErrorUDAF"), + Seq(predict, target).map(df(_).expr), + isUDAFBridgeRequired = true) + .toAggregateExpression() + toDF(Alias(udaf, udaf.prettyName)() :: Nil) + } + + /** + * @see hivemall.evaluation.R2UDAF + * @group evaluation + */ + def r2(predict: String, target: String): DataFrame = { + checkType(predict, DoubleType) + checkType(target, DoubleType) + val udaf = HiveUDAFFunction( + "r2", + new HiveFunctionWrapper("hivemall.evaluation.R2UDAF"), + Seq(predict, target).map(df(_).expr), + isUDAFBridgeRequired = true) + .toAggregateExpression() + toDF(Alias(udaf, udaf.prettyName)() :: Nil) + } + + /** + * @see hivemall.evaluation.LogarithmicLossUDAF + * @group evaluation + */ + def logloss(predict: String, target: String): DataFrame = { + checkType(predict, DoubleType) + checkType(target, DoubleType) + val udaf = HiveUDAFFunction( + "logloss", + new HiveFunctionWrapper("hivemall.evaluation.LogarithmicLossUDAF"), + Seq(predict, target).map(df(_).expr), + isUDAFBridgeRequired = true) + .toAggregateExpression() + toDF(Alias(udaf, udaf.prettyName)() :: Nil) + } + + /** + * @see hivemall.evaluation.F1ScoreUDAF + * @group evaluation + */ + def f1score(predict: String, target: String): DataFrame = { + // checkType(target, ArrayType(IntegerType, false)) + // checkType(predict, ArrayType(IntegerType, false)) + val udaf = HiveUDAFFunction( + "f1score", + new HiveFunctionWrapper("hivemall.evaluation.F1ScoreUDAF"), + Seq(predict, target).map(df(_).expr), + isUDAFBridgeRequired = true) + .toAggregateExpression() + toDF(Alias(udaf, udaf.prettyName)() :: Nil) + } + + /** + * @see hivemall.evaluation.NDCGUDAF + * @group evaluation + */ + @scala.annotation.varargs + def ndcg(rankItems: String, correctItems: String, others: String*): DataFrame = { + val udaf = HiveUDAFFunction( + "ndcg", + new HiveFunctionWrapper("hivemall.evaluation.NDCGUDAF"), + (rankItems +: correctItems +: others).map(df(_).expr), + isUDAFBridgeRequired = false) + .toAggregateExpression() + toDF(Alias(udaf, 
udaf.prettyName)() :: Nil) + } + + /** + * @see hivemall.evaluation.PrecisionUDAF + * @group evaluation + */ + @scala.annotation.varargs + def precision_at(rankItems: String, correctItems: String, others: String*): DataFrame = { + val udaf = HiveUDAFFunction( + "precision_at", + new HiveFunctionWrapper("hivemall.evaluation.PrecisionUDAF"), + (rankItems +: correctItems +: others).map(df(_).expr), + isUDAFBridgeRequired = false) + .toAggregateExpression() + toDF(Alias(udaf, udaf.prettyName)() :: Nil) + } + + /** + * @see hivemall.evaluation.RecallUDAF + * @group evaluation + */ + @scala.annotation.varargs + def recall_at(rankItems: String, correctItems: String, others: String*): DataFrame = { + val udaf = HiveUDAFFunction( + "recall_at", + new HiveFunctionWrapper("hivemall.evaluation.RecallUDAF"), + (rankItems +: correctItems +: others).map(df(_).expr), + isUDAFBridgeRequired = false) + .toAggregateExpression() + toDF(Alias(udaf, udaf.prettyName)() :: Nil) + } + + /** + * @see hivemall.evaluation.HitRateUDAF + * @group evaluation + */ + @scala.annotation.varargs + def hitrate(rankItems: String, correctItems: String, others: String*): DataFrame = { + val udaf = HiveUDAFFunction( + "hitrate", + new HiveFunctionWrapper("hivemall.evaluation.HitRateUDAF"), + (rankItems +: correctItems +: others).map(df(_).expr), + isUDAFBridgeRequired = false) + .toAggregateExpression() + toDF(Alias(udaf, udaf.prettyName)() :: Nil) + } + + /** + * @see hivemall.evaluation.MRRUDAF + * @group evaluation + */ + @scala.annotation.varargs + def mrr(rankItems: String, correctItems: String, others: String*): DataFrame = { + val udaf = HiveUDAFFunction( + "mrr", + new HiveFunctionWrapper("hivemall.evaluation.MRRUDAF"), + (rankItems +: correctItems +: others).map(df(_).expr), + isUDAFBridgeRequired = false) + .toAggregateExpression() + toDF(Alias(udaf, udaf.prettyName)() :: Nil) + } + + /** + * @see hivemall.evaluation.MAPUDAF + * @group evaluation + */ + @scala.annotation.varargs + def average_precision(rankItems: String, correctItems: String, others: String*): DataFrame = { + val udaf = HiveUDAFFunction( + "average_precision", + new HiveFunctionWrapper("hivemall.evaluation.MAPUDAF"), + (rankItems +: correctItems +: others).map(df(_).expr), + isUDAFBridgeRequired = false) + .toAggregateExpression() + toDF(Alias(udaf, udaf.prettyName)() :: Nil) + } + + /** + * @see hivemall.evaluation.AUCUDAF + * @group evaluation + */ + @scala.annotation.varargs + def auc(args: String*): DataFrame = { + val udaf = HiveUDAFFunction( + "auc", + new HiveFunctionWrapper("hivemall.evaluation.AUCUDAF"), + args.map(df(_).expr), + isUDAFBridgeRequired = false) + .toAggregateExpression() + toDF(Alias(udaf, udaf.prettyName)() :: Nil) + } + + /** + * @see hivemall.topicmodel.LDAPredictUDAF + * @group topicmodel + */ + @scala.annotation.varargs + def lda_predict(word: String, value: String, label: String, lambda: String, others: String*) + : DataFrame = { + checkType(word, StringType) + checkType(value, DoubleType) + checkType(label, IntegerType) + checkType(lambda, DoubleType) + val udaf = HiveUDAFFunction( + "lda_predict", + new HiveFunctionWrapper("hivemall.topicmodel.LDAPredictUDAF"), + (word +: value +: label +: lambda +: others).map(df(_).expr), + isUDAFBridgeRequired = false) + .toAggregateExpression() + toDF(Alias(udaf, udaf.prettyName)() :: Nil) + } + + /** + * @see hivemall.topicmodel.PLSAPredictUDAF + * @group topicmodel + */ + @scala.annotation.varargs + def plsa_predict(word: String, value: String, label: String, prob: String, others: 
String*) + : DataFrame = { + checkType(word, StringType) + checkType(value, DoubleType) + checkType(label, IntegerType) + checkType(prob, DoubleType) + val udaf = HiveUDAFFunction( + "plsa_predict", + new HiveFunctionWrapper("hivemall.topicmodel.PLSAPredictUDAF"), + (word +: value +: label +: prob +: others).map(df(_).expr), + isUDAFBridgeRequired = false) + .toAggregateExpression() + toDF(Alias(udaf, udaf.prettyName)() :: Nil) + } + + /** + * @see hivemall.ftvec.text.TermFrequencyUDAF + * @group ftvec.text + */ + def tf(text: String): DataFrame = { + checkType(text, StringType) + val udaf = HiveUDAFFunction( + "tf", + new HiveFunctionWrapper("hivemall.ftvec.text.TermFrequencyUDAF"), + Seq(text).map(df(_).expr), + isUDAFBridgeRequired = true) + .toAggregateExpression() + toDF(Alias(udaf, udaf.prettyName)() :: Nil) + } + + /** + * @see hivemall.ftvec.trans.OnehotEncodingUDAF + * @group ftvec.trans + */ + @scala.annotation.varargs + def onehot_encoding(feature: String, others: String*): DataFrame = { + val udaf = HiveUDAFFunction( + "onehot_encoding", + new HiveFunctionWrapper("hivemall.ftvec.trans.OnehotEncodingUDAF"), + (feature +: others).map(df(_).expr), + isUDAFBridgeRequired = false) + .toAggregateExpression() + toDF(Alias(udaf, udaf.prettyName)() :: Nil) + } + + /** + * @see hivemall.ftvec.selection.SignalNoiseRatioUDAF + * @group ftvec.selection + */ + def snr(feature: String, label: String): DataFrame = { + val udaf = HiveUDAFFunction( + "snr", + new HiveFunctionWrapper("hivemall.ftvec.selection.SignalNoiseRatioUDAF"), + Seq(feature, label).map(df(_).expr), + isUDAFBridgeRequired = false) + .toAggregateExpression() + toDF(Alias(udaf, udaf.prettyName)() :: Nil) + } + + /** + * @see hivemall.tools.array.ArrayAvgGenericUDAF + * @group tools.array + */ + def array_avg(ar: String): DataFrame = { + val udaf = HiveUDAFFunction( + "array_avg", + new HiveFunctionWrapper("hivemall.tools.array.ArrayAvgGenericUDAF"), + Seq(ar).map(df(_).expr), + isUDAFBridgeRequired = false) + .toAggregateExpression() + toDF(Alias(udaf, udaf.prettyName)() :: Nil) + } + + /** + * @see hivemall.tools.array.ArraySumUDAF + * @group tools.array + */ + def array_sum(ar: String): DataFrame = { + val udaf = HiveUDAFFunction( + "array_sum", + new HiveFunctionWrapper("hivemall.tools.array.ArraySumUDAF"), + Seq(ar).map(df(_).expr), + isUDAFBridgeRequired = true) + .toAggregateExpression() + toDF(Alias(udaf, udaf.prettyName)() :: Nil) + } + + /** + * @see hivemall.tools.bits.BitsCollectUDAF + * @group tools.bits + */ + def bits_collect(x: String): DataFrame = { + val udaf = HiveUDAFFunction( + "bits_collect", + new HiveFunctionWrapper("hivemall.tools.bits.BitsCollectUDAF"), + Seq(x).map(df(_).expr), + isUDAFBridgeRequired = false) + .toAggregateExpression() + toDF(Alias(udaf, udaf.prettyName)() :: Nil) + } + + /** + * @see hivemall.tools.list.UDAFToOrderedList + * @group tools.list + */ + @scala.annotation.varargs + def to_ordered_list(value: String, others: String*): DataFrame = { + val udaf = HiveUDAFFunction( + "to_ordered_list", + new HiveFunctionWrapper("hivemall.tools.list.UDAFToOrderedList"), + (value +: others).map(df(_).expr), + isUDAFBridgeRequired = false) + .toAggregateExpression() + toDF(Alias(udaf, udaf.prettyName)() :: Nil) + } + + /** + * @see hivemall.tools.map.UDAFToMap + * @group tools.map + */ + def to_map(key: String, value: String): DataFrame = { + val udaf = HiveUDAFFunction( + "to_map", + new HiveFunctionWrapper("hivemall.tools.map.UDAFToMap"), + Seq(key, value).map(df(_).expr), + 
isUDAFBridgeRequired = false) + .toAggregateExpression() + toDF(Alias(udaf, udaf.prettyName)() :: Nil) + } + + /** + * @see hivemall.tools.map.UDAFToOrderedMap + * @group tools.map + */ + @scala.annotation.varargs + def to_ordered_map(key: String, value: String, others: String*): DataFrame = { + val udaf = HiveUDAFFunction( + "to_ordered_map", + new HiveFunctionWrapper("hivemall.tools.map.UDAFToOrderedMap"), + (key +: value +: others).map(df(_).expr), + isUDAFBridgeRequired = false) + .toAggregateExpression() + toDF(Alias(udaf, udaf.prettyName)() :: Nil) + } + + /** + * @see hivemall.tools.matrix.TransposeAndDotUDAF + * @group tools.matrix + */ + def transpose_and_dot(matrix0_row: String, matrix1_row: String): DataFrame = { + val udaf = HiveUDAFFunction( + "transpose_and_dot", + new HiveFunctionWrapper("hivemall.tools.matrix.TransposeAndDotUDAF"), + Seq(matrix0_row, matrix1_row).map(df(_).expr), + isUDAFBridgeRequired = false) + .toAggregateExpression() + toDF(Alias(udaf, udaf.prettyName)() :: Nil) + } + + /** + * @see hivemall.tools.math.L2NormUDAF + * @group tools.math + */ + def l2_norm(xi: String): DataFrame = { + val udaf = HiveUDAFFunction( + "l2_norm", + new HiveFunctionWrapper("hivemall.tools.math.L2NormUDAF"), + Seq(xi).map(df(_).expr), + isUDAFBridgeRequired = true) + .toAggregateExpression() + toDF(Alias(udaf, udaf.prettyName)() :: Nil) + } + + /** + * [[RelationalGroupedDataset]] has the three values as private fields, so, to inject Hivemall + * aggregate functions, we fetch them via Java Reflections. + */ + private val df = getPrivateField[DataFrame]("org$apache$spark$sql$RelationalGroupedDataset$$df") + private val groupingExprs = getPrivateField[Seq[Expression]]("groupingExprs") + private val groupType = getPrivateField[RelationalGroupedDataset.GroupType]("groupType") + + private def getPrivateField[T](name: String): T = { + val field = groupBy.getClass.getDeclaredField(name) + field.setAccessible(true) + field.get(groupBy).asInstanceOf[T] + } + + private def toDF(aggExprs: Seq[Expression]): DataFrame = { + val aggregates = if (df.sqlContext.conf.dataFrameRetainGroupColumns) { + groupingExprs ++ aggExprs + } else { + aggExprs + } + + val aliasedAgg = aggregates.map(alias) + + groupType match { + case RelationalGroupedDataset.GroupByType => + Dataset.ofRows( + df.sparkSession, Aggregate(groupingExprs, aliasedAgg, df.logicalPlan)) + case RelationalGroupedDataset.RollupType => + Dataset.ofRows( + df.sparkSession, Aggregate(Seq(Rollup(groupingExprs)), aliasedAgg, df.logicalPlan)) + case RelationalGroupedDataset.CubeType => + Dataset.ofRows( + df.sparkSession, Aggregate(Seq(Cube(groupingExprs)), aliasedAgg, df.logicalPlan)) + case RelationalGroupedDataset.PivotType(pivotCol, values) => + val aliasedGrps = groupingExprs.map(alias) + Dataset.ofRows( + df.sparkSession, Pivot(aliasedGrps, pivotCol, values, aggExprs, df.logicalPlan)) + } + } + + private def alias(expr: Expression): NamedExpression = expr match { + case u: UnresolvedAttribute => UnresolvedAlias(u) + case expr: NamedExpression => expr + case expr: Expression => Alias(expr, expr.prettyName)() + } + + private def checkType(colName: String, expected: DataType) = { + val dataType = df.resolve(colName).dataType + if (dataType != expected) { + throw new AnalysisException( + s""""$colName" must be $expected, however it is $dataType""") + } + } +} + +object HivemallGroupedDataset { + + /** + * Implicitly inject the [[HivemallGroupedDataset]] into [[RelationalGroupedDataset]]. 
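+   *
+   * A short usage sketch (assuming a DataFrame `df` with `user`, `key` and `value`
+   * columns; the names are only illustrative):
+   * {{{
+   *   import org.apache.spark.sql.hive.HivemallGroupedDataset._
+   *
+   *   // `to_map` resolves to hivemall.tools.map.UDAFToMap through this conversion
+   *   val keyValuePerUser = df.groupBy("user").to_map("key", "value")
+   * }}}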
+ */ + implicit def relationalGroupedDatasetToHivemallOne( + groupBy: RelationalGroupedDataset): HivemallGroupedDataset = { + new HivemallGroupedDataset(groupBy) + } +} diff --git a/spark/spark-2.3/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala b/spark/spark-2.3/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala new file mode 100644 index 000000000..8323d2286 --- /dev/null +++ b/spark/spark-2.3/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala @@ -0,0 +1,2249 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.hive + +import java.util.UUID + +import org.apache.spark.annotation.Experimental +import org.apache.spark.internal.Logging +import org.apache.spark.ml.feature.HivemallFeature +import org.apache.spark.ml.linalg.{DenseVector, SparseVector, VectorUDT} +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.encoders.RowEncoder +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.Inner +import org.apache.spark.sql.catalyst.plans.logical.{Generate, JoinTopK, LogicalPlan} +import org.apache.spark.sql.execution.UserProvidedPlanner +import org.apache.spark.sql.execution.datasources.csv.{CsvToStruct, StructToCsv} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + + +/** + * Hivemall wrapper and some utility functions for DataFrame. These functions below derives + * from `resources/ddl/define-all-as-permanent.hive`. 
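+ *
+ * As a rough usage sketch, assuming the implicit conversion in the [[HivemallOps]]
+ * companion object is in scope and `trainDf` has a `features` column (an array of
+ * feature strings or an ML [[org.apache.spark.ml.linalg.Vector]], the latter being
+ * converted by `toHivemallFeatures`) and a `label` column (names are illustrative):
+ * {{{
+ *   import org.apache.spark.sql.hive.HivemallOps._
+ *
+ *   // Trains a logistic regressor and averages the per-partition weights by feature
+ *   val model = trainDf
+ *     .train_logistic_regr(trainDf("features"), trainDf("label"))
+ *     .groupBy("feature").agg("weight" -> "avg")
+ * }}}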
+ * + * @groupname regression + * @groupname classifier + * @groupname classifier.multiclass + * @groupname recommend + * @groupname topicmodel + * @groupname geospatial + * @groupname smile + * @groupname xgboost + * @groupname anomaly + * @groupname knn.similarity + * @groupname knn.distance + * @groupname knn.lsh + * @groupname ftvec + * @groupname ftvec.amplify + * @groupname ftvec.hashing + * @groupname ftvec.paring + * @groupname ftvec.scaling + * @groupname ftvec.selection + * @groupname ftvec.conv + * @groupname ftvec.trans + * @groupname ftvec.ranking + * @groupname tools + * @groupname tools.array + * @groupname tools.bits + * @groupname tools.compress + * @groupname tools.map + * @groupname tools.text + * @groupname misc + * + * A list of unsupported functions is as follows: + * * smile + * - guess_attribute_types + * * mapred functions + * - taskid + * - jobid + * - rownum + * - distcache_gets + * - jobconf_gets + * * matrix factorization + * - mf_predict + * - train_mf_sgd + * - train_mf_adagrad + * - train_bprmf + * - bprmf_predict + * * Factorization Machine + * - fm_predict + * - train_fm + * - train_ffm + * - ffm_predict + */ +final class HivemallOps(df: DataFrame) extends Logging { + import internal.HivemallOpsImpl._ + + private lazy val _sparkSession = df.sparkSession + private lazy val _strategy = new UserProvidedPlanner(_sparkSession.sqlContext.conf) + + /** + * @see [[hivemall.regression.GeneralRegressorUDTF]] + * @group regression + */ + @scala.annotation.varargs + def train_regressor(exprs: Column*): DataFrame = withTypedPlan { + planHiveGenericUDTF( + df, + "hivemall.regression.GeneralRegressorUDTF", + "train_regressor", + setMixServs(toHivemallFeatures(exprs)), + Seq("feature", "weight") + ) + } + + /** + * @see [[hivemall.regression.AdaDeltaUDTF]] + * @group regression + */ + @scala.annotation.varargs + def train_adadelta_regr(exprs: Column*): DataFrame = withTypedPlan { + planHiveGenericUDTF( + df, + "hivemall.regression.AdaDeltaUDTF", + "train_adadelta_regr", + setMixServs(toHivemallFeatures(exprs)), + Seq("feature", "weight") + ) + } + + /** + * @see [[hivemall.regression.AdaGradUDTF]] + * @group regression + */ + @scala.annotation.varargs + def train_adagrad_regr(exprs: Column*): DataFrame = withTypedPlan { + planHiveGenericUDTF( + df, + "hivemall.regression.AdaGradUDTF", + "train_adagrad_regr", + setMixServs(toHivemallFeatures(exprs)), + Seq("feature", "weight") + ) + } + + /** + * @see [[hivemall.regression.AROWRegressionUDTF]] + * @group regression + */ + @scala.annotation.varargs + def train_arow_regr(exprs: Column*): DataFrame = withTypedPlan { + planHiveGenericUDTF( + df, + "hivemall.regression.AROWRegressionUDTF", + "train_arow_regr", + setMixServs(toHivemallFeatures(exprs)), + Seq("feature", "weight", "conv") + ) + } + + /** + * @see [[hivemall.regression.AROWRegressionUDTF.AROWe]] + * @group regression + */ + @scala.annotation.varargs + def train_arowe_regr(exprs: Column*): DataFrame = withTypedPlan { + planHiveGenericUDTF( + df, + "hivemall.regression.AROWRegressionUDTF$AROWe", + "train_arowe_regr", + setMixServs(toHivemallFeatures(exprs)), + Seq("feature", "weight", "conv") + ) + } + + /** + * @see [[hivemall.regression.AROWRegressionUDTF.AROWe2]] + * @group regression + */ + @scala.annotation.varargs + def train_arowe2_regr(exprs: Column*): DataFrame = withTypedPlan { + planHiveGenericUDTF( + df, + "hivemall.regression.AROWRegressionUDTF$AROWe2", + "train_arowe2_regr", + setMixServs(toHivemallFeatures(exprs)), + Seq("feature", "weight", "conv") + 
) + } + + /** + * @see [[hivemall.regression.LogressUDTF]] + * @group regression + */ + @scala.annotation.varargs + def train_logistic_regr(exprs: Column*): DataFrame = withTypedPlan { + planHiveGenericUDTF( + df, + "hivemall.regression.LogressUDTF", + "train_logistic_regr", + setMixServs(toHivemallFeatures(exprs)), + Seq("feature", "weight") + ) + } + + /** + * @see [[hivemall.regression.PassiveAggressiveRegressionUDTF]] + * @group regression + */ + @scala.annotation.varargs + def train_pa1_regr(exprs: Column*): DataFrame = withTypedPlan { + planHiveGenericUDTF( + df, + "hivemall.regression.PassiveAggressiveRegressionUDTF", + "train_pa1_regr", + setMixServs(toHivemallFeatures(exprs)), + Seq("feature", "weight") + ) + } + + /** + * @see [[hivemall.regression.PassiveAggressiveRegressionUDTF.PA1a]] + * @group regression + */ + @scala.annotation.varargs + def train_pa1a_regr(exprs: Column*): DataFrame = withTypedPlan { + planHiveGenericUDTF( + df, + "hivemall.regression.PassiveAggressiveRegressionUDTF$PA1a", + "train_pa1a_regr", + setMixServs(toHivemallFeatures(exprs)), + Seq("feature", "weight") + ) + } + + /** + * @see [[hivemall.regression.PassiveAggressiveRegressionUDTF.PA2]] + * @group regression + */ + @scala.annotation.varargs + def train_pa2_regr(exprs: Column*): DataFrame = withTypedPlan { + planHiveGenericUDTF( + df, + "hivemall.regression.PassiveAggressiveRegressionUDTF$PA2", + "train_pa2_regr", + setMixServs(toHivemallFeatures(exprs)), + Seq("feature", "weight") + ) + } + + /** + * @see [[hivemall.regression.PassiveAggressiveRegressionUDTF.PA2a]] + * @group regression + */ + @scala.annotation.varargs + def train_pa2a_regr(exprs: Column*): DataFrame = withTypedPlan { + planHiveGenericUDTF( + df, + "hivemall.regression.PassiveAggressiveRegressionUDTF$PA2a", + "train_pa2a_regr", + setMixServs(toHivemallFeatures(exprs)), + Seq("feature", "weight") + ) + } + + /** + * @see [[hivemall.classifier.GeneralClassifierUDTF]] + * @group classifier + */ + @scala.annotation.varargs + def train_classifier(exprs: Column*): DataFrame = withTypedPlan { + planHiveGenericUDTF( + df, + "hivemall.classifier.GeneralClassifierUDTF", + "train_classifier", + setMixServs(toHivemallFeatures(exprs)), + Seq("feature", "weight") + ) + } + + /** + * @see [[hivemall.classifier.PerceptronUDTF]] + * @group classifier + */ + @scala.annotation.varargs + def train_perceptron(exprs: Column*): DataFrame = withTypedPlan { + planHiveGenericUDTF( + df, + "hivemall.classifier.PerceptronUDTF", + "train_perceptron", + setMixServs(toHivemallFeatures(exprs)), + Seq("feature", "weight") + ) + } + + /** + * @see [[hivemall.classifier.PassiveAggressiveUDTF]] + * @group classifier + */ + @scala.annotation.varargs + def train_pa(exprs: Column*): DataFrame = withTypedPlan { + planHiveGenericUDTF( + df, + "hivemall.classifier.PassiveAggressiveUDTF", + "train_pa", + setMixServs(toHivemallFeatures(exprs)), + Seq("feature", "weight") + ) + } + + /** + * @see [[hivemall.classifier.PassiveAggressiveUDTF.PA1]] + * @group classifier + */ + @scala.annotation.varargs + def train_pa1(exprs: Column*): DataFrame = withTypedPlan { + planHiveGenericUDTF( + df, + "hivemall.classifier.PassiveAggressiveUDTF$PA1", + "train_pa1", + setMixServs(toHivemallFeatures(exprs)), + Seq("feature", "weight") + ) + } + + /** + * @see [[hivemall.classifier.PassiveAggressiveUDTF.PA2]] + * @group classifier + */ + @scala.annotation.varargs + def train_pa2(exprs: Column*): DataFrame = withTypedPlan { + planHiveGenericUDTF( + df, + 
"hivemall.classifier.PassiveAggressiveUDTF$PA2", + "train_pa2", + setMixServs(toHivemallFeatures(exprs)), + Seq("feature", "weight") + ) + } + + /** + * @see [[hivemall.classifier.ConfidenceWeightedUDTF]] + * @group classifier + */ + @scala.annotation.varargs + def train_cw(exprs: Column*): DataFrame = withTypedPlan { + planHiveGenericUDTF( + df, + "hivemall.classifier.ConfidenceWeightedUDTF", + "train_cw", + setMixServs(toHivemallFeatures(exprs)), + Seq("feature", "weight", "conv") + ) + } + + /** + * @see [[hivemall.classifier.AROWClassifierUDTF]] + * @group classifier + */ + @scala.annotation.varargs + def train_arow(exprs: Column*): DataFrame = withTypedPlan { + planHiveGenericUDTF( + df, + "hivemall.classifier.AROWClassifierUDTF", + "train_arow", + setMixServs(toHivemallFeatures(exprs)), + Seq("feature", "weight", "conv") + ) + } + + /** + * @see [[hivemall.classifier.AROWClassifierUDTF.AROWh]] + * @group classifier + */ + @scala.annotation.varargs + def train_arowh(exprs: Column*): DataFrame = withTypedPlan { + planHiveGenericUDTF( + df, + "hivemall.classifier.AROWClassifierUDTF$AROWh", + "train_arowh", + setMixServs(toHivemallFeatures(exprs)), + Seq("feature", "weight", "conv") + ) + } + + /** + * @see [[hivemall.classifier.SoftConfideceWeightedUDTF.SCW1]] + * @group classifier + */ + @scala.annotation.varargs + def train_scw(exprs: Column*): DataFrame = withTypedPlan { + planHiveGenericUDTF( + df, + "hivemall.classifier.SoftConfideceWeightedUDTF$SCW1", + "train_scw", + setMixServs(toHivemallFeatures(exprs)), + Seq("feature", "weight", "conv") + ) + } + + /** + * @see [[hivemall.classifier.SoftConfideceWeightedUDTF.SCW1]] + * @group classifier + */ + @scala.annotation.varargs + def train_scw2(exprs: Column*): DataFrame = withTypedPlan { + planHiveGenericUDTF( + df, + "hivemall.classifier.SoftConfideceWeightedUDTF$SCW2", + "train_scw2", + setMixServs(toHivemallFeatures(exprs)), + Seq("feature", "weight", "conv") + ) + } + + /** + * @see [[hivemall.classifier.AdaGradRDAUDTF]] + * @group classifier + */ + @scala.annotation.varargs + def train_adagrad_rda(exprs: Column*): DataFrame = withTypedPlan { + planHiveGenericUDTF( + df, + "hivemall.classifier.AdaGradRDAUDTF", + "train_adagrad_rda", + setMixServs(toHivemallFeatures(exprs)), + Seq("feature", "weight") + ) + } + + /** + * @see [[hivemall.classifier.KernelExpansionPassiveAggressiveUDTF]] + * @group classifier + */ + @scala.annotation.varargs + def train_kpa(exprs: Column*): DataFrame = withTypedPlan { + planHiveGenericUDTF( + df, + "hivemall.classifier.KernelExpansionPassiveAggressiveUDTF", + "train_kpa", + setMixServs(toHivemallFeatures(exprs)), + Seq("h", "hk", "w0", "w1", "w2", "w3") + ) + } + + /** + * @see [[hivemall.classifier.multiclass.MulticlassPerceptronUDTF]] + * @group classifier.multiclass + */ + @scala.annotation.varargs + def train_multiclass_perceptron(exprs: Column*): DataFrame = withTypedPlan { + planHiveGenericUDTF( + df, + "hivemall.classifier.multiclass.MulticlassPerceptronUDTF", + "train_multiclass_perceptron", + setMixServs(toHivemallFeatures(exprs)), + Seq("label", "feature", "weight") + ) + } + + /** + * @see [[hivemall.classifier.multiclass.MulticlassPassiveAggressiveUDTF]] + * @group classifier.multiclass + */ + @scala.annotation.varargs + def train_multiclass_pa(exprs: Column*): DataFrame = withTypedPlan { + planHiveGenericUDTF( + df, + "hivemall.classifier.multiclass.MulticlassPassiveAggressiveUDTF", + "train_multiclass_pa", + setMixServs(toHivemallFeatures(exprs)), + Seq("label", "feature", "weight") + 
) + } + + /** + * @see [[hivemall.classifier.multiclass.MulticlassPassiveAggressiveUDTF.PA1]] + * @group classifier.multiclass + */ + @scala.annotation.varargs + def train_multiclass_pa1(exprs: Column*): DataFrame = withTypedPlan { + planHiveGenericUDTF( + df, + "hivemall.classifier.multiclass.MulticlassPassiveAggressiveUDTF$PA1", + "train_multiclass_pa1", + setMixServs(toHivemallFeatures(exprs)), + Seq("label", "feature", "weight") + ) + } + + /** + * @see [[hivemall.classifier.multiclass.MulticlassPassiveAggressiveUDTF.PA2]] + * @group classifier.multiclass + */ + @scala.annotation.varargs + def train_multiclass_pa2(exprs: Column*): DataFrame = withTypedPlan { + planHiveGenericUDTF( + df, + "hivemall.classifier.multiclass.MulticlassPassiveAggressiveUDTF$PA2", + "train_multiclass_pa2", + setMixServs(toHivemallFeatures(exprs)), + Seq("label", "feature", "weight") + ) + } + + /** + * @see [[hivemall.classifier.multiclass.MulticlassConfidenceWeightedUDTF]] + * @group classifier.multiclass + */ + @scala.annotation.varargs + def train_multiclass_cw(exprs: Column*): DataFrame = withTypedPlan { + planHiveGenericUDTF( + df, + "hivemall.classifier.multiclass.MulticlassConfidenceWeightedUDTF", + "train_multiclass_cw", + setMixServs(toHivemallFeatures(exprs)), + Seq("label", "feature", "weight", "conv") + ) + } + + /** + * @see [[hivemall.classifier.multiclass.MulticlassAROWClassifierUDTF]] + * @group classifier.multiclass + */ + @scala.annotation.varargs + def train_multiclass_arow(exprs: Column*): DataFrame = withTypedPlan { + planHiveGenericUDTF( + df, + "hivemall.classifier.multiclass.MulticlassAROWClassifierUDTF", + "train_multiclass_arow", + setMixServs(toHivemallFeatures(exprs)), + Seq("label", "feature", "weight", "conv") + ) + } + + /** + * @see [[hivemall.classifier.multiclass.MulticlassAROWClassifierUDTF.AROWh]] + * @group classifier.multiclass + */ + @scala.annotation.varargs + def train_multiclass_arowh(exprs: Column*): DataFrame = withTypedPlan { + planHiveGenericUDTF( + df, + "hivemall.classifier.multiclass.MulticlassAROWClassifierUDTF$AROWh", + "train_multiclass_arowh", + setMixServs(toHivemallFeatures(exprs)), + Seq("label", "feature", "weight", "conv") + ) + } + + /** + * @see [[hivemall.classifier.multiclass.MulticlassSoftConfidenceWeightedUDTF.SCW1]] + * @group classifier.multiclass + */ + @scala.annotation.varargs + def train_multiclass_scw(exprs: Column*): DataFrame = withTypedPlan { + planHiveGenericUDTF( + df, + "hivemall.classifier.multiclass.MulticlassSoftConfidenceWeightedUDTF$SCW1", + "train_multiclass_scw", + setMixServs(toHivemallFeatures(exprs)), + Seq("label", "feature", "weight", "conv") + ) + } + + /** + * @see [[hivemall.classifier.multiclass.MulticlassSoftConfidenceWeightedUDTF.SCW2]] + * @group classifier.multiclass + */ + @scala.annotation.varargs + def train_multiclass_scw2(exprs: Column*): DataFrame = withTypedPlan { + planHiveGenericUDTF( + df, + "hivemall.classifier.multiclass.MulticlassSoftConfidenceWeightedUDTF$SCW2", + "train_multiclass_scw2", + setMixServs(toHivemallFeatures(exprs)), + Seq("label", "feature", "weight", "conv") + ) + } + + /** + * @see [[hivemall.recommend.SlimUDTF]] + * @group recommend + */ + @scala.annotation.varargs + def train_slim(exprs: Column*): DataFrame = withTypedPlan { + planHiveGenericUDTF( + df, + "hivemall.recommend.SlimUDTF", + "train_slim", + setMixServs(toHivemallFeatures(exprs)), + Seq("j", "nn", "w") + ) + } + + /** + * @see [[hivemall.topicmodel.LDAUDTF]] + * @group topicmodel + */ + @scala.annotation.varargs + def 
train_lda(exprs: Column*): DataFrame = withTypedPlan { + planHiveGenericUDTF( + df, + "hivemall.topicmodel.LDAUDTF", + "train_lda", + setMixServs(toHivemallFeatures(exprs)), + Seq("topic", "word", "score") + ) + } + + /** + * @see [[hivemall.topicmodel.PLSAUDTF]] + * @group topicmodel + */ + @scala.annotation.varargs + def train_plsa(exprs: Column*): DataFrame = withTypedPlan { + planHiveGenericUDTF( + df, + "hivemall.topicmodel.PLSAUDTF", + "train_plsa", + setMixServs(toHivemallFeatures(exprs)), + Seq("topic", "word", "score") + ) + } + + /** + * @see [[hivemall.smile.regression.RandomForestRegressionUDTF]] + * @group smile + */ + @scala.annotation.varargs + def train_randomforest_regressor(exprs: Column*): DataFrame = withTypedPlan { + planHiveGenericUDTF( + df, + "hivemall.smile.regression.RandomForestRegressionUDTF", + "train_randomforest_regressor", + setMixServs(toHivemallFeatures(exprs)), + Seq("model_id", "model_type", "pred_model", "var_importance", "oob_errors", "oob_tests") + ) + } + + /** + * @see [[hivemall.smile.classification.RandomForestClassifierUDTF]] + * @group smile + */ + @scala.annotation.varargs + def train_randomforest_classifier(exprs: Column*): DataFrame = withTypedPlan { + planHiveGenericUDTF( + df, + "hivemall.smile.classification.RandomForestClassifierUDTF", + "train_randomforest_classifier", + setMixServs(toHivemallFeatures(exprs)), + Seq("model_id", "model_type", "pred_model", "var_importance", "oob_errors", "oob_tests") + ) + } + + /** + * :: Experimental :: + * @see [[hivemall.xgboost.regression.XGBoostRegressionUDTF]] + * @group xgboost + */ + @Experimental + @scala.annotation.varargs + def train_xgboost_regr(exprs: Column*): DataFrame = withTypedPlan { + planHiveGenericUDTF( + df, + "hivemall.xgboost.regression.XGBoostRegressionUDTF", + "train_xgboost_regr", + setMixServs(toHivemallFeatures(exprs)), + Seq("model_id", "pred_model") + ) + } + + /** + * :: Experimental :: + * @see [[hivemall.xgboost.classification.XGBoostBinaryClassifierUDTF]] + * @group xgboost + */ + @Experimental + @scala.annotation.varargs + def train_xgboost_classifier(exprs: Column*): DataFrame = withTypedPlan { + planHiveGenericUDTF( + df, + "hivemall.xgboost.classification.XGBoostBinaryClassifierUDTF", + "train_xgboost_classifier", + setMixServs(toHivemallFeatures(exprs)), + Seq("model_id", "pred_model") + ) + } + + /** + * :: Experimental :: + * @see [[hivemall.xgboost.classification.XGBoostMulticlassClassifierUDTF]] + * @group xgboost + */ + @Experimental + @scala.annotation.varargs + def train_xgboost_multiclass_classifier(exprs: Column*): DataFrame = withTypedPlan { + planHiveGenericUDTF( + df, + "hivemall.xgboost.classification.XGBoostMulticlassClassifierUDTF", + "train_xgboost_multiclass_classifier", + setMixServs(toHivemallFeatures(exprs)), + Seq("model_id", "pred_model") + ) + } + + /** + * :: Experimental :: + * @see [[hivemall.xgboost.tools.XGBoostPredictUDTF]] + * @group xgboost + */ + @Experimental + @scala.annotation.varargs + def xgboost_predict(exprs: Column*): DataFrame = withTypedPlan { + planHiveGenericUDTF( + df, + "hivemall.xgboost.tools.XGBoostPredictUDTF", + "xgboost_predict", + setMixServs(toHivemallFeatures(exprs)), + Seq("rowid", "predicted") + ) + } + + /** + * :: Experimental :: + * @see [[hivemall.xgboost.tools.XGBoostMulticlassPredictUDTF]] + * @group xgboost + */ + @Experimental + @scala.annotation.varargs + def xgboost_multiclass_predict(exprs: Column*): DataFrame = withTypedPlan { + planHiveGenericUDTF( + df, + 
"hivemall.xgboost.tools.XGBoostMulticlassPredictUDTF", + "xgboost_multiclass_predict", + setMixServs(toHivemallFeatures(exprs)), + Seq("rowid", "label", "probability") + ) + } + + /** + * @see [[hivemall.knn.similarity.DIMSUMMapperUDTF]] + * @group knn.similarity + */ + @scala.annotation.varargs + def dimsum_mapper(exprs: Column*): DataFrame = withTypedPlan { + planHiveGenericUDTF( + df, + "hivemall.knn.similarity.DIMSUMMapperUDTF", + "dimsum_mapper", + exprs, + Seq("j", "k", "b_jk") + ) + } + + /** + * @see [[hivemall.knn.lsh.MinHashUDTF]] + * @group knn.lsh + */ + @scala.annotation.varargs + def minhash(exprs: Column*): DataFrame = withTypedPlan { + planHiveGenericUDTF( + df, + "hivemall.knn.lsh.MinHashUDTF", + "minhash", + exprs, + Seq("clusterid", "item") + ) + } + + /** + * @see [[hivemall.ftvec.amplify.AmplifierUDTF]] + * @group ftvec.amplify + */ + @scala.annotation.varargs + def amplify(exprs: Column*): DataFrame = withTypedPlan { + planHiveGenericUDTF( + df, + "hivemall.ftvec.amplify.AmplifierUDTF", + "amplify", + exprs, + Seq("clusterid", "item") + ) + } + + /** + * @see [[hivemall.ftvec.amplify.RandomAmplifierUDTF]] + * @group ftvec.amplify + */ + @scala.annotation.varargs + def rand_amplify(exprs: Column*): DataFrame = withTypedPlan { + throw new UnsupportedOperationException("`rand_amplify` not supported yet") + } + + /** + * Amplifies and shuffle data inside partitions. + * @group ftvec.amplify + */ + def part_amplify(xtimes: Column): DataFrame = { + val xtimesInt = xtimes.expr match { + case Literal(v: Any, IntegerType) => v.asInstanceOf[Int] + case e => throw new AnalysisException("`xtimes` must be integer, however " + e) + } + val rdd = df.rdd.mapPartitions({ iter => + val elems = iter.flatMap{ row => + Seq.fill[Row](xtimesInt)(row) + } + // Need to check how this shuffling affects results + scala.util.Random.shuffle(elems) + }, true) + df.sqlContext.createDataFrame(rdd, df.schema) + } + + /** + * Quantifies input columns. 
+ * @see [[hivemall.ftvec.conv.QuantifyColumnsUDTF]] + * @group ftvec.conv + */ + @scala.annotation.varargs + def quantify(exprs: Column*): DataFrame = withTypedPlan { + planHiveGenericUDTF( + df, + "hivemall.ftvec.conv.QuantifyColumnsUDTF", + "quantify", + exprs, + (0 until exprs.size - 1).map(i => s"c$i") + ) + } + + /** + * @see [[hivemall.ftvec.trans.BinarizeLabelUDTF]] + * @group ftvec.trans + */ + @scala.annotation.varargs + def binarize_label(exprs: Column*): DataFrame = withTypedPlan { + planHiveGenericUDTF( + df, + "hivemall.ftvec.trans.BinarizeLabelUDTF", + "binarize_label", + exprs, + (0 until exprs.size - 1).map(i => s"c$i") + ) + } + + /** + * @see [[hivemall.ftvec.trans.QuantifiedFeaturesUDTF]] + * @group ftvec.trans + */ + @scala.annotation.varargs + def quantified_features(exprs: Column*): DataFrame = withTypedPlan { + planHiveGenericUDTF( + df, + "hivemall.ftvec.trans.QuantifiedFeaturesUDTF", + "quantified_features", + exprs, + Seq("features") + ) + } + + /** + * @see [[hivemall.ftvec.ranking.BprSamplingUDTF]] + * @group ftvec.ranking + */ + @scala.annotation.varargs + def bpr_sampling(exprs: Column*): DataFrame = withTypedPlan { + planHiveGenericUDTF( + df, + "hivemall.ftvec.ranking.BprSamplingUDTF", + "bpr_sampling", + exprs, + Seq("user", "pos_item", "neg_item") + ) + } + + /** + * @see [[hivemall.ftvec.ranking.ItemPairsSamplingUDTF]] + * @group ftvec.ranking + */ + @scala.annotation.varargs + def item_pairs_sampling(exprs: Column*): DataFrame = withTypedPlan { + planHiveGenericUDTF( + df, + "hivemall.ftvec.ranking.ItemPairsSamplingUDTF", + "item_pairs_sampling", + exprs, + Seq("pos_item_id", "neg_item_id") + ) + } + + /** + * @see [[hivemall.ftvec.ranking.PopulateNotInUDTF]] + * @group ftvec.ranking + */ + @scala.annotation.varargs + def populate_not_in(exprs: Column*): DataFrame = withTypedPlan { + planHiveGenericUDTF( + df, + "hivemall.ftvec.ranking.PopulateNotInUDTF", + "populate_not_in", + exprs, + Seq("item") + ) + } + + /** + * Splits Seq[String] into pieces. + * @group ftvec + */ + def explode_array(features: Column): DataFrame = { + df.explode(features) { case Row(v: Seq[_]) => + // Type erasure removes the component type in Seq + v.map(s => HivemallFeature(s.asInstanceOf[String])) + } + } + + /** + * Splits [[Vector]] into pieces. + * @group ftvec + */ + def explode_vector(features: Column): DataFrame = { + val elementSchema = StructType( + StructField("feature", StringType) :: StructField("weight", DoubleType) :: Nil) + val explodeFunc: Row => TraversableOnce[InternalRow] = (row: Row) => { + row.get(0) match { + case dv: DenseVector => + dv.values.zipWithIndex.map { + case (value, index) => + InternalRow(UTF8String.fromString(s"$index"), value) + } + case sv: SparseVector => + sv.values.zip(sv.indices).map { + case (value, index) => + InternalRow(UTF8String.fromString(s"$index"), value) + } + } + } + withTypedPlan { + Generate( + UserDefinedGenerator(elementSchema, explodeFunc, features.expr :: Nil), + unrequiredChildIndex = Seq.empty, + outer = false, None, + generatorOutput = Nil, + df.logicalPlan) + } + } + + /** + * @see [[hivemall.tools.GenerateSeriesUDTF]] + * @group tools + */ + def generate_series(start: Column, end: Column): DataFrame = withTypedPlan { + planHiveGenericUDTF( + df, + "hivemall.tools.GenerateSeriesUDTF", + "generate_series", + start :: end :: Nil, + Seq("generate_series") + ) + } + + /** + * Returns `top-k` records for each `group`. 
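+   *
+   * A minimal sketch (column names are illustrative; `k` must be a non-zero integer
+   * literal):
+   * {{{
+   *   import org.apache.spark.sql.functions.lit
+   *
+   *   // Top-2 rows per `userId`, ranked by `score`; a `rank` column is added to the output
+   *   val top2 = df.each_top_k(lit(2), df("score"), df("userId"))
+   * }}}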
+ * @group misc + */ + def each_top_k(k: Column, score: Column, group: Column*): DataFrame = withTypedPlan { + val kInt = k.expr match { + case Literal(v: Any, IntegerType) => v.asInstanceOf[Int] + case e => throw new AnalysisException("`k` must be integer, however " + e) + } + if (kInt == 0) { + throw new AnalysisException("`k` must not have 0") + } + val clusterDf = df.repartition(group: _*).sortWithinPartitions(group: _*) + .select(score, Column("*")) + val analyzedPlan = clusterDf.queryExecution.analyzed + val inputAttrs = analyzedPlan.output + val scoreExpr = BindReferences.bindReference(analyzedPlan.expressions.head, inputAttrs) + val groupNames = group.map { _.expr match { + case ne: NamedExpression => ne.name + case ua: UnresolvedAttribute => ua.name + }} + val groupExprs = analyzedPlan.expressions.filter { + case ne: NamedExpression => groupNames.contains(ne.name) + }.map { e => + BindReferences.bindReference(e, inputAttrs) + } + val rankField = StructField("rank", IntegerType) + Generate( + generator = EachTopK( + k = kInt, + scoreExpr = scoreExpr, + groupExprs = groupExprs, + elementSchema = StructType( + rankField +: inputAttrs.map(d => StructField(d.name, d.dataType)) + ), + children = inputAttrs + ), + unrequiredChildIndex = Seq.empty, + outer = false, + qualifier = None, + generatorOutput = Seq(rankField.name).map(UnresolvedAttribute(_)) ++ inputAttrs, + child = analyzedPlan + ) + } + + /** + * :: Experimental :: + * Joins input two tables with the given keys and the top-k highest `score` values. + * @group misc + */ + @Experimental + def top_k_join(k: Column, right: DataFrame, joinExprs: Column, score: Column) + : DataFrame = withTypedPlanInCustomStrategy { + val kInt = k.expr match { + case Literal(v: Any, IntegerType) => v.asInstanceOf[Int] + case e => throw new AnalysisException("`k` must be integer, however " + e) + } + if (kInt == 0) { + throw new AnalysisException("`k` must not have 0") + } + JoinTopK(kInt, df.logicalPlan, right.logicalPlan, Inner, Option(joinExprs.expr))(score.named) + } + + private def doFlatten(schema: StructType, separator: Char, prefixParts: Seq[String] = Seq.empty) + : Seq[Column] = { + schema.fields.flatMap { f => + val colNameParts = prefixParts :+ f.name + f.dataType match { + case st: StructType => + doFlatten(st, separator, colNameParts) + case _ => + col(colNameParts.mkString(".")).as(colNameParts.mkString(separator.toString)) :: Nil + } + } + } + + // Converts string representation of a character to actual character + @throws[IllegalArgumentException] + private def toChar(str: String): Char = { + if (str.length == 1) { + str.charAt(0) match { + case '$' | '_' | '.' => str.charAt(0) + case _ => throw new IllegalArgumentException( + "Must use '$', '_', or '.' for separator, but got " + str) + } + } else { + throw new IllegalArgumentException( + s"Separator cannot be more than one character: $str") + } + } + + /** + * Flattens a nested schema into a flat one. 
+ * @group misc + * + * For example: + * {{{ + * scala> val df = Seq((0, (1, (3.0, "a")), (5, 0.9))).toDF() + * scala> df.printSchema + * root + * |-- _1: integer (nullable = false) + * |-- _2: struct (nullable = true) + * | |-- _1: integer (nullable = false) + * | |-- _2: struct (nullable = true) + * | | |-- _1: double (nullable = false) + * | | |-- _2: string (nullable = true) + * |-- _3: struct (nullable = true) + * | |-- _1: integer (nullable = false) + * | |-- _2: double (nullable = false) + * + * scala> df.flatten(separator = "$").printSchema + * root + * |-- _1: integer (nullable = false) + * |-- _2$_1: integer (nullable = true) + * |-- _2$_2$_1: double (nullable = true) + * |-- _2$_2$_2: string (nullable = true) + * |-- _3$_1: integer (nullable = true) + * |-- _3$_2: double (nullable = true) + * }}} + */ + def flatten(separator: String = "$"): DataFrame = + df.select(doFlatten(df.schema, toChar(separator)): _*) + + /** + * @see [[hivemall.dataset.LogisticRegressionDataGeneratorUDTF]] + * @group misc + */ + @scala.annotation.varargs + def lr_datagen(exprs: Column*): Dataset[Row] = withTypedPlan { + planHiveGenericUDTF( + df, + "hivemall.dataset.LogisticRegressionDataGeneratorUDTFWrapper", + "lr_datagen", + exprs, + Seq("label", "features") + ) + } + + /** + * Returns all the columns as Seq[Column] in this [[DataFrame]]. + */ + private[sql] def cols: Seq[Column] = { + df.schema.fields.map(col => df.col(col.name)).toSeq + } + + /** + * :: Experimental :: + * If a parameter '-mix' does not exist in a 3rd argument, + * set it from an environmental variable + * 'HIVEMALL_MIX_SERVERS'. + * + * TODO: This could work if '--deploy-mode' has 'client'; + * otherwise, we need to set HIVEMALL_MIX_SERVERS + * in all possible spark workers. + */ + @Experimental + private def setMixServs(exprs: Seq[Column]): Seq[Column] = { + val mixes = System.getenv("HIVEMALL_MIX_SERVERS") + if (mixes != null && !mixes.isEmpty()) { + val groupId = df.sqlContext.sparkContext.applicationId + "-" + UUID.randomUUID + logInfo(s"set '${mixes}' as default mix servers (session: ${groupId})") + exprs.size match { + case 2 => exprs :+ Column( + Literal.create(s"-mix ${mixes} -mix_session ${groupId}", StringType)) + /** TODO: Add codes in the case where exprs.size == 3. */ + case _ => exprs + } + } else { + exprs + } + } + + /** + * If the input is a [[Vector]], transform it into Hivemall features. + */ + @inline private def toHivemallFeatures(exprs: Seq[Column]): Seq[Column] = { + df.select(exprs: _*).queryExecution.analyzed.schema.zip(exprs).map { + case (StructField(_, _: VectorUDT, _, _), c) => HivemallUtils.to_hivemall_features(c) + case (_, c) => c + } + } + + /** + * A convenient function to wrap a logical plan and produce a DataFrame. + */ + @inline private def withTypedPlan(logicalPlan: => LogicalPlan): DataFrame = { + val queryExecution = _sparkSession.sessionState.executePlan(logicalPlan) + val outputSchema = queryExecution.sparkPlan.schema + new Dataset[Row](df.sparkSession, queryExecution, RowEncoder(outputSchema)) + } + + @inline private def withTypedPlanInCustomStrategy(logicalPlan: => LogicalPlan) + : DataFrame = { + // Inject custom strategies + if (!_sparkSession.experimental.extraStrategies.contains(_strategy)) { + _sparkSession.experimental.extraStrategies = Seq(_strategy) + } + withTypedPlan(logicalPlan) + } +} + +object HivemallOps { + import internal.HivemallOpsImpl._ + + /** + * Implicitly inject the [[HivemallOps]] into [[DataFrame]]. 
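+   *
+   * A short sketch of how this injection is typically used (`df` is any [[DataFrame]];
+   * the `lr_datagen` option string below is only illustrative):
+   * {{{
+   *   import org.apache.spark.sql.functions.lit
+   *   import org.apache.spark.sql.hive.HivemallOps._
+   *
+   *   df.select(hivemall_version())                          // from the companion object
+   *   df.lr_datagen(lit("-n_examples 1000 -n_features 10"))  // added by this conversion
+   * }}}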
+ */ + implicit def dataFrameToHivemallOps(df: DataFrame): HivemallOps = + new HivemallOps(df) + + /** + * @see [[hivemall.HivemallVersionUDF]] + * @group misc + */ + def hivemall_version(): Column = withExpr { + planHiveUDF( + "hivemall.HivemallVersionUDF", + "hivemall_version", + Nil + ) + } + + /** + * @see [[hivemall.geospatial.TileUDF]] + * @group geospatial + */ + def tile(lat: Column, lon: Column, zoom: Column): Column = withExpr { + planHiveGenericUDF( + "hivemall.geospatial.TileUDF", + "tile", + lat :: lon :: zoom :: Nil + ) + } + + /** + * @see [[hivemall.geospatial.MapURLUDF]] + * @group geospatial + */ + @scala.annotation.varargs + def map_url(exprs: Column*): Column = withExpr { + planHiveGenericUDF( + "hivemall.geospatial.MapURLUDF", + "map_url", + exprs + ) + } + + /** + * @see [[hivemall.geospatial.Lat2TileYUDF]] + * @group geospatial + */ + def lat2tiley(lat: Column, zoom: Column): Column = withExpr { + planHiveGenericUDF( + "hivemall.geospatial.Lat2TileYUDF", + "lat2tiley", + lat :: zoom :: Nil + ) + } + + /** + * @see [[hivemall.geospatial.Lon2TileXUDF]] + * @group geospatial + */ + def lon2tilex(lon: Column, zoom: Column): Column = withExpr { + planHiveGenericUDF( + "hivemall.geospatial.Lon2TileXUDF", + "lon2tilex", + lon :: zoom :: Nil + ) + } + + /** + * @see [[hivemall.geospatial.TileX2LonUDF]] + * @group geospatial + */ + def tilex2lon(x: Column, zoom: Column): Column = withExpr { + planHiveGenericUDF( + "hivemall.geospatial.TileX2LonUDF", + "tilex2lon", + x :: zoom :: Nil + ) + } + + /** + * @see [[hivemall.geospatial.TileY2LatUDF]] + * @group geospatial + */ + def tiley2lat(y: Column, zoom: Column): Column = withExpr { + planHiveGenericUDF( + "hivemall.geospatial.TileY2LatUDF", + "tiley2lat", + y :: zoom :: Nil + ) + } + + /** + * @see [[hivemall.geospatial.HaversineDistanceUDF]] + * @group geospatial + */ + @scala.annotation.varargs + def haversine_distance(exprs: Column*): Column = withExpr { + planHiveGenericUDF( + "hivemall.geospatial.HaversineDistanceUDF", + "haversine_distance", + exprs + ) + } + + /** + * @see [[hivemall.smile.tools.TreePredictUDF]] + * @group smile + */ + @scala.annotation.varargs + def tree_predict(exprs: Column*): Column = withExpr { + planHiveGenericUDF( + "hivemall.smile.tools.TreePredictUDF", + "tree_predict", + exprs + ) + } + + /** + * @see [[hivemall.smile.tools.TreeExportUDF]] + * @group smile + */ + @scala.annotation.varargs + def tree_export(exprs: Column*): Column = withExpr { + planHiveGenericUDF( + "hivemall.smile.tools.TreeExportUDF", + "tree_export", + exprs + ) + } + + /** + * @see [[hivemall.anomaly.ChangeFinderUDF]] + * @group anomaly + */ + @scala.annotation.varargs + def changefinder(exprs: Column*): Column = withExpr { + planHiveGenericUDF( + "hivemall.anomaly.ChangeFinderUDF", + "changefinder", + exprs + ) + } + + /** + * @see [[hivemall.anomaly.SingularSpectrumTransformUDF]] + * @group anomaly + */ + @scala.annotation.varargs + def sst(exprs: Column*): Column = withExpr { + planHiveGenericUDF( + "hivemall.anomaly.SingularSpectrumTransformUDF", + "sst", + exprs + ) + } + + /** + * @see [[hivemall.knn.similarity.CosineSimilarityUDF]] + * @group knn.similarity + */ + @scala.annotation.varargs + def cosine_similarity(exprs: Column*): Column = withExpr { + planHiveGenericUDF( + "hivemall.knn.similarity.CosineSimilarityUDF", + "cosine_similarity", + exprs + ) + } + + /** + * @see [[hivemall.knn.similarity.JaccardIndexUDF]] + * @group knn.similarity + */ + @scala.annotation.varargs + def jaccard_similarity(exprs: Column*): 
Column = withExpr { + planHiveUDF( + "hivemall.knn.similarity.JaccardIndexUDF", + "jaccard_similarity", + exprs + ) + } + + /** + * @see [[hivemall.knn.similarity.AngularSimilarityUDF]] + * @group knn.similarity + */ + @scala.annotation.varargs + def angular_similarity(exprs: Column*): Column = withExpr { + planHiveGenericUDF( + "hivemall.knn.similarity.AngularSimilarityUDF", + "angular_similarity", + exprs + ) + } + + /** + * @see [[hivemall.knn.similarity.EuclidSimilarity]] + * @group knn.similarity + */ + @scala.annotation.varargs + def euclid_similarity(exprs: Column*): Column = withExpr { + planHiveGenericUDF( + "hivemall.knn.similarity.EuclidSimilarity", + "euclid_similarity", + exprs + ) + } + + /** + * @see [[hivemall.knn.similarity.Distance2SimilarityUDF]] + * @group knn.similarity + */ + @scala.annotation.varargs + def distance2similarity(exprs: Column*): Column = withExpr { + // TODO: Need a wrapper class because of using unsupported types + planHiveGenericUDF( + "hivemall.knn.similarity.Distance2SimilarityUDF", + "distance2similarity", + exprs + ) + } + + /** + * @see [[hivemall.knn.distance.HammingDistanceUDF]] + * @group knn.distance + */ + @scala.annotation.varargs + def hamming_distance(exprs: Column*): Column = withExpr { + planHiveUDF( + "hivemall.knn.distance.HammingDistanceUDF", + "hamming_distance", + exprs + ) + } + + /** + * @see [[hivemall.knn.distance.PopcountUDF]] + * @group knn.distance + */ + @scala.annotation.varargs + def popcnt(exprs: Column*): Column = withExpr { + planHiveUDF( + "hivemall.knn.distance.PopcountUDF", + "popcnt", + exprs + ) + } + + /** + * @see [[hivemall.knn.distance.KLDivergenceUDF]] + * @group knn.distance + */ + @scala.annotation.varargs + def kld(exprs: Column*): Column = withExpr { + planHiveUDF( + "hivemall.knn.distance.KLDivergenceUDF", + "kld", + exprs + ) + } + + /** + * @see [[hivemall.knn.distance.EuclidDistanceUDF]] + * @group knn.distance + */ + @scala.annotation.varargs + def euclid_distance(exprs: Column*): Column = withExpr { + planHiveGenericUDF( + "hivemall.knn.distance.EuclidDistanceUDF", + "euclid_distance", + exprs + ) + } + + /** + * @see [[hivemall.knn.distance.CosineDistanceUDF]] + * @group knn.distance + */ + @scala.annotation.varargs + def cosine_distance(exprs: Column*): Column = withExpr { + planHiveGenericUDF( + "hivemall.knn.distance.CosineDistanceUDF", + "cosine_distance", + exprs + ) + } + + /** + * @see [[hivemall.knn.distance.AngularDistanceUDF]] + * @group knn.distance + */ + @scala.annotation.varargs + def angular_distance(exprs: Column*): Column = withExpr { + planHiveGenericUDF( + "hivemall.knn.distance.AngularDistanceUDF", + "angular_distance", + exprs + ) + } + + /** + * @see [[hivemall.knn.distance.JaccardDistanceUDF]] + * @group knn.distance + */ + @scala.annotation.varargs + def jaccard_distance(exprs: Column*): Column = withExpr { + planHiveUDF( + "hivemall.knn.distance.JaccardDistanceUDF", + "jaccard_distance", + exprs + ) + } + + /** + * @see [[hivemall.knn.distance.ManhattanDistanceUDF]] + * @group knn.distance + */ + @scala.annotation.varargs + def manhattan_distance(exprs: Column*): Column = withExpr { + planHiveGenericUDF( + "hivemall.knn.distance.ManhattanDistanceUDF", + "manhattan_distance", + exprs + ) + } + + /** + * @see [[hivemall.knn.distance.MinkowskiDistanceUDF]] + * @group knn.distance + */ + @scala.annotation.varargs + def minkowski_distance(exprs: Column*): Column = withExpr { + planHiveGenericUDF( + "hivemall.knn.distance.MinkowskiDistanceUDF", + "minkowski_distance", + exprs + ) 
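+    // Argument order is (features1, features2, p), as exercised by the `knn.distance` tests
+    // added later in this patch.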
+ } + + /** + * @see [[hivemall.knn.lsh.bBitMinHashUDF]] + * @group knn.lsh + */ + @scala.annotation.varargs + def bbit_minhash(exprs: Column*): Column = withExpr { + planHiveUDF( + "hivemall.knn.lsh.bBitMinHashUDF", + "bbit_minhash", + exprs + ) + } + + /** + * @see [[hivemall.knn.lsh.MinHashesUDFWrapper]] + * @group knn.lsh + */ + @scala.annotation.varargs + def minhashes(exprs: Column*): Column = withExpr { + planHiveGenericUDF( + "hivemall.knn.lsh.MinHashesUDFWrapper", + "minhashes", + exprs + ) + } + + /** + * Returns new features with `1.0` (bias) appended to the input features. + * @see [[hivemall.ftvec.AddBiasUDFWrapper]] + * @group ftvec + */ + def add_bias(expr: Column): Column = withExpr { + planHiveGenericUDF( + "hivemall.ftvec.AddBiasUDFWrapper", + "add_bias", + expr :: Nil + ) + } + + /** + * @see [[hivemall.ftvec.ExtractFeatureUDFWrapper]] + * @group ftvec + * + * TODO: This throws java.lang.ClassCastException because + * HiveInspectors.toInspector has a bug in spark. + * Need to fix it later. + */ + def extract_feature(expr: Column): Column = withExpr { + planHiveGenericUDF( + "hivemall.ftvec.ExtractFeatureUDFWrapper", + "extract_feature", + expr :: Nil + ) + }.as("feature") + + /** + * @see [[hivemall.ftvec.ExtractWeightUDFWrapper]] + * @group ftvec + * + * TODO: This throws java.lang.ClassCastException because + * HiveInspectors.toInspector has a bug in spark. + * Need to fix it later. + */ + def extract_weight(expr: Column): Column = withExpr { + planHiveGenericUDF( + "hivemall.ftvec.ExtractWeightUDFWrapper", + "extract_weight", + expr :: Nil + ) + }.as("value") + + /** + * @see [[hivemall.ftvec.AddFeatureIndexUDFWrapper]] + * @group ftvec + */ + def add_feature_index(features: Column): Column = withExpr { + planHiveGenericUDF( + "hivemall.ftvec.AddFeatureIndexUDFWrapper", + "add_feature_index", + features :: Nil + ) + } + + /** + * @see [[hivemall.ftvec.SortByFeatureUDFWrapper]] + * @group ftvec + */ + def sort_by_feature(expr: Column): Column = withExpr { + planHiveGenericUDF( + "hivemall.ftvec.SortByFeatureUDFWrapper", + "sort_by_feature", + expr :: Nil + ) + } + + /** + * @see [[hivemall.ftvec.hashing.MurmurHash3UDF]] + * @group ftvec.hashing + */ + def mhash(expr: Column): Column = withExpr { + planHiveUDF( + "hivemall.ftvec.hashing.MurmurHash3UDF", + "mhash", + expr :: Nil + ) + } + + /** + * @see [[hivemall.ftvec.hashing.Sha1UDF]] + * @group ftvec.hashing + */ + @scala.annotation.varargs + def sha1(exprs: Column*): Column = withExpr { + planHiveUDF( + "hivemall.ftvec.hashing.Sha1UDF", + "sha1", + exprs + ) + } + + /** + * @see [[hivemall.ftvec.hashing.ArrayHashValuesUDF]] + * @group ftvec.hashing + */ + @scala.annotation.varargs + def array_hash_values(exprs: Column*): Column = withExpr { + planHiveUDF( + "hivemall.ftvec.hashing.ArrayHashValuesUDF", + "array_hash_values", + exprs + ) + } + + /** + * @see [[hivemall.ftvec.hashing.ArrayPrefixedHashValuesUDF]] + * @group ftvec.hashing + */ + @scala.annotation.varargs + def prefixed_hash_values(exprs: Column*): Column = withExpr { + planHiveUDF( + "hivemall.ftvec.hashing.ArrayPrefixedHashValuesUDF", + "prefixed_hash_values", + exprs + ) + } + + /** + * @see [[hivemall.ftvec.hashing.FeatureHashingUDF]] + * @group ftvec.hashing + */ + @scala.annotation.varargs + def feature_hashing(exprs: Column*): Column = withExpr { + planHiveGenericUDF( + "hivemall.ftvec.hashing.FeatureHashingUDF", + "feature_hashing", + exprs + ) + } + + /** + * @see [[hivemall.ftvec.pairing.PolynomialFeaturesUDF]] + * @group ftvec.paring + */ + 
@scala.annotation.varargs + def polynomial_features(exprs: Column*): Column = withExpr { + planHiveUDF( + "hivemall.ftvec.pairing.PolynomialFeaturesUDF", + "polynomial_features", + exprs + ) + } + + /** + * @see [[hivemall.ftvec.pairing.PoweredFeaturesUDF]] + * @group ftvec.paring + */ + @scala.annotation.varargs + def powered_features(exprs: Column*): Column = withExpr { + planHiveUDF( + "hivemall.ftvec.pairing.PoweredFeaturesUDF", + "powered_features", + exprs + ) + } + + /** + * @see [[hivemall.ftvec.scaling.RescaleUDF]] + * @group ftvec.scaling + */ + def rescale(value: Column, max: Column, min: Column): Column = withExpr { + planHiveUDF( + "hivemall.ftvec.scaling.RescaleUDF", + "rescale", + value.cast(FloatType) :: max :: min :: Nil + ) + } + + /** + * @see [[hivemall.ftvec.scaling.ZScoreUDF]] + * @group ftvec.scaling + */ + @scala.annotation.varargs + def zscore(exprs: Column*): Column = withExpr { + planHiveUDF( + "hivemall.ftvec.scaling.ZScoreUDF", + "zscore", + exprs + ) + } + + /** + * @see [[hivemall.ftvec.scaling.L2NormalizationUDFWrapper]] + * @group ftvec.scaling + */ + def l2_normalize(expr: Column): Column = withExpr { + planHiveGenericUDF( + "hivemall.ftvec.scaling.L2NormalizationUDFWrapper", + "normalize", + expr :: Nil + ) + } + + /** + * @see [[hivemall.ftvec.selection.ChiSquareUDF]] + * @group ftvec.selection + */ + def chi2(observed: Column, expected: Column): Column = withExpr { + planHiveGenericUDF( + "hivemall.ftvec.selection.ChiSquareUDF", + "chi2", + Seq(observed, expected) + ) + } + + /** + * @see [[hivemall.ftvec.conv.ToDenseFeaturesUDF]] + * @group ftvec.conv + */ + @scala.annotation.varargs + def to_dense_features(exprs: Column*): Column = withExpr { + planHiveUDF( + "hivemall.ftvec.conv.ToDenseFeaturesUDF", + "to_dense_features", + exprs + ) + } + + /** + * @see [[hivemall.ftvec.conv.ToSparseFeaturesUDF]] + * @group ftvec.conv + */ + @scala.annotation.varargs + def to_sparse_features(exprs: Column*): Column = withExpr { + planHiveUDF( + "hivemall.ftvec.conv.ToSparseFeaturesUDF", + "to_sparse_features", + exprs + ) + } + + /** + * @see [[hivemall.ftvec.binning.FeatureBinningUDF]] + * @group ftvec.conv + */ + @scala.annotation.varargs + def feature_binning(exprs: Column*): Column = withExpr { + planHiveGenericUDF( + "hivemall.ftvec.binning.FeatureBinningUDF", + "feature_binning", + exprs + ) + } + + /** + * @see [[hivemall.ftvec.trans.VectorizeFeaturesUDF]] + * @group ftvec.trans + */ + @scala.annotation.varargs + def vectorize_features(exprs: Column*): Column = withExpr { + planHiveGenericUDF( + "hivemall.ftvec.trans.VectorizeFeaturesUDF", + "vectorize_features", + exprs + ) + } + + /** + * @see [[hivemall.ftvec.trans.CategoricalFeaturesUDF]] + * @group ftvec.trans + */ + @scala.annotation.varargs + def categorical_features(exprs: Column*): Column = withExpr { + planHiveGenericUDF( + "hivemall.ftvec.trans.CategoricalFeaturesUDF", + "categorical_features", + exprs + ) + } + + /** + * @see [[hivemall.ftvec.trans.FFMFeaturesUDF]] + * @group ftvec.trans + */ + @scala.annotation.varargs + def ffm_features(exprs: Column*): Column = withExpr { + planHiveGenericUDF( + "hivemall.ftvec.trans.FFMFeaturesUDF", + "ffm_features", + exprs + ) + } + + /** + * @see [[hivemall.ftvec.trans.IndexedFeatures]] + * @group ftvec.trans + */ + @scala.annotation.varargs + def indexed_features(exprs: Column*): Column = withExpr { + planHiveGenericUDF( + "hivemall.ftvec.trans.IndexedFeatures", + "indexed_features", + exprs + ) + } + + /** + * @see 
[[hivemall.ftvec.trans.QuantitativeFeaturesUDF]] + * @group ftvec.trans + */ + @scala.annotation.varargs + def quantitative_features(exprs: Column*): Column = withExpr { + planHiveGenericUDF( + "hivemall.ftvec.trans.QuantitativeFeaturesUDF", + "quantitative_features", + exprs + ) + } + + /** + * @see [[hivemall.ftvec.trans.AddFieldIndicesUDF]] + * @group ftvec.trans + */ + def add_field_indices(features: Column): Column = withExpr { + planHiveGenericUDF( + "hivemall.ftvec.trans.AddFieldIndicesUDF", + "add_field_indices", + features :: Nil + ) + } + + /** + * @see [[hivemall.tools.ConvertLabelUDF]] + * @group tools + */ + def convert_label(label: Column): Column = withExpr { + planHiveUDF( + "hivemall.tools.ConvertLabelUDF", + "convert_label", + label :: Nil + ) + } + + /** + * @see [[hivemall.tools.RankSequenceUDF]] + * @group tools + */ + def x_rank(key: Column): Column = withExpr { + planHiveUDF( + "hivemall.tools.RankSequenceUDF", + "x_rank", + key :: Nil + ) + } + + /** + * @see [[hivemall.tools.array.AllocFloatArrayUDF]] + * @group tools.array + */ + def float_array(nDims: Column): Column = withExpr { + planHiveUDF( + "hivemall.tools.array.AllocFloatArrayUDF", + "float_array", + nDims :: Nil + ) + } + + /** + * @see [[hivemall.tools.array.ArrayRemoveUDF]] + * @group tools.array + */ + def array_remove(original: Column, target: Column): Column = withExpr { + planHiveUDF( + "hivemall.tools.array.ArrayRemoveUDF", + "array_remove", + original :: target :: Nil + ) + } + + /** + * @see [[hivemall.tools.array.SortAndUniqArrayUDF]] + * @group tools.array + */ + def sort_and_uniq_array(ar: Column): Column = withExpr { + planHiveUDF( + "hivemall.tools.array.SortAndUniqArrayUDF", + "sort_and_uniq_array", + ar :: Nil + ) + } + + /** + * @see [[hivemall.tools.array.SubarrayEndWithUDF]] + * @group tools.array + */ + def subarray_endwith(original: Column, key: Column): Column = withExpr { + planHiveUDF( + "hivemall.tools.array.SubarrayEndWithUDF", + "subarray_endwith", + original :: key :: Nil + ) + } + + /** + * @see [[hivemall.tools.array.ArrayConcatUDF]] + * @group tools.array + */ + @scala.annotation.varargs + def array_concat(arrays: Column*): Column = withExpr { + planHiveGenericUDF( + "hivemall.tools.array.ArrayConcatUDF", + "array_concat", + arrays + ) + } + + /** + * @see [[hivemall.tools.array.SubarrayUDF]] + * @group tools.array + */ + def subarray(original: Column, fromIndex: Column, toIndex: Column): Column = withExpr { + planHiveUDF( + "hivemall.tools.array.SubarrayUDF", + "subarray", + original :: fromIndex :: toIndex :: Nil + ) + } + + /** + * @see [[hivemall.tools.array.ToStringArrayUDF]] + * @group tools.array + */ + def to_string_array(ar: Column): Column = withExpr { + planHiveUDF( + "hivemall.tools.array.ToStringArrayUDF", + "to_string_array", + ar :: Nil + ) + } + + /** + * @see [[hivemall.tools.array.ArrayIntersectUDF]] + * @group tools.array + */ + @scala.annotation.varargs + def array_intersect(arrays: Column*): Column = withExpr { + planHiveGenericUDF( + "hivemall.tools.array.ArrayIntersectUDF", + "array_intersect", + arrays + ) + } + + /** + * @see [[hivemall.tools.array.SelectKBestUDF]] + * @group tools.array + */ + def select_k_best(X: Column, importanceList: Column, k: Column): Column = withExpr { + planHiveGenericUDF( + "hivemall.tools.array.SelectKBestUDF", + "select_k_best", + Seq(X, importanceList, k) + ) + } + + /** + * @see [[hivemall.tools.bits.ToBitsUDF]] + * @group tools.bits + */ + def to_bits(indexes: Column): Column = withExpr { + planHiveGenericUDF( + 
"hivemall.tools.bits.ToBitsUDF", + "to_bits", + indexes :: Nil + ) + } + + /** + * @see [[hivemall.tools.bits.UnBitsUDF]] + * @group tools.bits + */ + def unbits(bitset: Column): Column = withExpr { + planHiveGenericUDF( + "hivemall.tools.bits.UnBitsUDF", + "unbits", + bitset :: Nil + ) + } + + /** + * @see [[hivemall.tools.bits.BitsORUDF]] + * @group tools.bits + */ + @scala.annotation.varargs + def bits_or(bits: Column*): Column = withExpr { + planHiveGenericUDF( + "hivemall.tools.bits.BitsORUDF", + "bits_or", + bits + ) + } + + /** + * @see [[hivemall.tools.compress.InflateUDF]] + * @group tools.compress + */ + @scala.annotation.varargs + def inflate(exprs: Column*): Column = withExpr { + planHiveGenericUDF( + "hivemall.tools.compress.InflateUDF", + "inflate", + exprs + ) + } + + /** + * @see [[hivemall.tools.compress.DeflateUDF]] + * @group tools.compress + */ + @scala.annotation.varargs + def deflate(exprs: Column*): Column = withExpr { + planHiveGenericUDF( + "hivemall.tools.compress.DeflateUDF", + "deflate", + exprs + ) + } + + /** + * @see [[hivemall.tools.map.MapGetSumUDF]] + * @group tools.map + */ + @scala.annotation.varargs + def map_get_sum(exprs: Column*): Column = withExpr { + planHiveUDF( + "hivemall.tools.map.MapGetSumUDF", + "map_get_sum", + exprs + ) + } + + /** + * @see [[hivemall.tools.map.MapTailNUDF]] + * @group tools.map + */ + @scala.annotation.varargs + def map_tail_n(exprs: Column*): Column = withExpr { + planHiveGenericUDF( + "hivemall.tools.map.MapTailNUDF", + "map_tail_n", + exprs + ) + } + + /** + * @see [[hivemall.tools.text.TokenizeUDF]] + * @group tools.text + */ + @scala.annotation.varargs + def tokenize(exprs: Column*): Column = withExpr { + planHiveUDF( + "hivemall.tools.text.TokenizeUDF", + "tokenize", + exprs + ) + } + + /** + * @see [[hivemall.tools.text.StopwordUDF]] + * @group tools.text + */ + def is_stopword(word: Column): Column = withExpr { + planHiveUDF( + "hivemall.tools.text.StopwordUDF", + "is_stopword", + word :: Nil + ) + } + + /** + * @see [[hivemall.tools.text.SingularizeUDF]] + * @group tools.text + */ + def singularize(word: Column): Column = withExpr { + planHiveUDF( + "hivemall.tools.text.SingularizeUDF", + "singularize", + word :: Nil + ) + } + + /** + * @see [[hivemall.tools.text.SplitWordsUDF]] + * @group tools.text + */ + @scala.annotation.varargs + def split_words(exprs: Column*): Column = withExpr { + planHiveUDF( + "hivemall.tools.text.SplitWordsUDF", + "split_words", + exprs + ) + } + + /** + * @see [[hivemall.tools.text.NormalizeUnicodeUDF]] + * @group tools.text + */ + @scala.annotation.varargs + def normalize_unicode(exprs: Column*): Column = withExpr { + planHiveUDF( + "hivemall.tools.text.NormalizeUnicodeUDF", + "normalize_unicode", + exprs + ) + } + + /** + * @see [[hivemall.tools.text.Base91UDF]] + * @group tools.text + */ + def base91(bin: Column): Column = withExpr { + planHiveGenericUDF( + "hivemall.tools.text.Base91UDF", + "base91", + bin :: Nil + ) + } + + /** + * @see [[hivemall.tools.text.Unbase91UDF]] + * @group tools.text + */ + def unbase91(base91String: Column): Column = withExpr { + planHiveGenericUDF( + "hivemall.tools.text.Unbase91UDF", + "unbase91", + base91String :: Nil + ) + } + + /** + * @see [[hivemall.tools.text.WordNgramsUDF]] + * @group tools.text + */ + def word_ngrams(words: Column, minSize: Column, maxSize: Column): Column = withExpr { + planHiveUDF( + "hivemall.tools.text.WordNgramsUDF", + "word_ngrams", + words :: minSize :: maxSize :: Nil + ) + } + + /** + * @see 
[[hivemall.tools.math.SigmoidGenericUDF]]
+   * @group misc
+   */
+  def sigmoid(expr: Column): Column = {
+    val one: () => Literal = () => Literal.create(1.0, DoubleType)
+    Column(one()) / (Column(one()) + exp(-expr))
+  }
+
+  /**
+   * @see [[hivemall.tools.mapred.RowIdUDFWrapper]]
+   * @group misc
+   */
+  def rowid(): Column = withExpr {
+    planHiveGenericUDF(
+      "hivemall.tools.mapred.RowIdUDFWrapper",
+      "rowid",
+      Nil
+    )
+  }.as("rowid")
+
+  /**
+   * Parses a column containing a CSV string into a [[StructType]] with the specified schema.
+   * Returns `null` in the case of an unparseable string.
+   * @group misc
+   *
+   * @param e a string column containing CSV data.
+   * @param schema the schema to use when parsing the CSV string
+   * @param options options to control how the CSV is parsed. Accepts the same options as the
+   *                CSV data source.
+   */
+  def from_csv(e: Column, schema: StructType, options: Map[String, String]): Column = withExpr {
+    CsvToStruct(schema, options, e.expr)
+  }
+
+  /**
+   * Parses a column containing a CSV string into a [[StructType]] with the specified schema.
+   * Returns `null` in the case of an unparseable string.
+   * @group misc
+   *
+   * @param e a string column containing CSV data.
+   * @param schema the schema to use when parsing the CSV string
+   */
+  def from_csv(e: Column, schema: StructType): Column =
+    from_csv(e, schema, Map.empty[String, String])
+
+  /**
+   * Converts a column containing a [[StructType]] into a CSV string with the specified schema.
+   * Throws an exception in the case of an unsupported type.
+   * @group misc
+   *
+   * @param e a struct column.
+   * @param options options to control how the struct column is converted into a CSV string.
+   *                Accepts the same options as the CSV data source.
+   */
+  def to_csv(e: Column, options: Map[String, String]): Column = withExpr {
+    StructToCsv(options, e.expr)
+  }
+
+  /**
+   * Converts a column containing a [[StructType]] into a CSV string with the specified schema.
+   * Throws an exception in the case of an unsupported type.
+   * @group misc
+   *
+   * @param e a struct column.
+   */
+  def to_csv(e: Column): Column = to_csv(e, Map.empty[String, String])
+
+  /**
+   * A convenient function to wrap an expression and produce a Column.
+   */
+  @inline private def withExpr(expr: Expression): Column = Column(expr)
+}
diff --git a/spark/spark-2.3/src/main/scala/org/apache/spark/sql/hive/HivemallUtils.scala b/spark/spark-2.3/src/main/scala/org/apache/spark/sql/hive/HivemallUtils.scala
new file mode 100644
index 000000000..70cf00b92
--- /dev/null
+++ b/spark/spark-2.3/src/main/scala/org/apache/spark/sql/hive/HivemallUtils.scala
@@ -0,0 +1,146 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */ + +package org.apache.spark.sql.hive + +import org.apache.spark.ml.linalg.{BLAS, DenseVector, SparseVector, Vector, Vectors} +import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.expressions.UserDefinedFunction +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types._ + +object HivemallUtils { + + // # of maximum dimensions for feature vectors + private[this] val maxDims = 100000000 + + /** + * Check whether the given schema contains a column of the required data type. + * @param colName column name + * @param dataType required column data type + */ + private[this] def checkColumnType(schema: StructType, colName: String, dataType: DataType) + : Unit = { + val actualDataType = schema(colName).dataType + require(actualDataType.equals(dataType), + s"Column $colName must be of type $dataType but was actually $actualDataType.") + } + + def to_vector_func(dense: Boolean, dims: Int): Seq[String] => Vector = { + if (dense) { + // Dense features + i: Seq[String] => { + val features = new Array[Double](dims) + i.map { ft => + val s = ft.split(":").ensuring(_.size == 2) + features(s(0).toInt) = s(1).toDouble + } + Vectors.dense(features) + } + } else { + // Sparse features + i: Seq[String] => { + val features = i.map { ft => + // val s = ft.split(":").ensuring(_.size == 2) + val s = ft.split(":") + (s(0).toInt, s(1).toDouble) + } + Vectors.sparse(dims, features) + } + } + } + + def to_hivemall_features_func(): Vector => Array[String] = { + case dv: DenseVector => + dv.values.zipWithIndex.map { + case (value, index) => s"$index:$value" + } + case sv: SparseVector => + sv.values.zip(sv.indices).map { + case (value, index) => s"$index:$value" + } + case v => + throw new IllegalArgumentException(s"Do not support vector type ${v.getClass}") + } + + def append_bias_func(): Vector => Vector = { + case dv: DenseVector => + val inputValues = dv.values + val inputLength = inputValues.length + val outputValues = Array.ofDim[Double](inputLength + 1) + System.arraycopy(inputValues, 0, outputValues, 0, inputLength) + outputValues(inputLength) = 1.0 + Vectors.dense(outputValues) + case sv: SparseVector => + val inputValues = sv.values + val inputIndices = sv.indices + val inputValuesLength = inputValues.length + val dim = sv.size + val outputValues = Array.ofDim[Double](inputValuesLength + 1) + val outputIndices = Array.ofDim[Int](inputValuesLength + 1) + System.arraycopy(inputValues, 0, outputValues, 0, inputValuesLength) + System.arraycopy(inputIndices, 0, outputIndices, 0, inputValuesLength) + outputValues(inputValuesLength) = 1.0 + outputIndices(inputValuesLength) = dim + Vectors.sparse(dim + 1, outputIndices, outputValues) + case v => + throw new IllegalArgumentException(s"Do not support vector type ${v.getClass}") + } + + /** + * Transforms Hivemall features into a [[Vector]]. + */ + def to_vector(dense: Boolean = false, dims: Int = maxDims): UserDefinedFunction = { + udf(to_vector_func(dense, dims)) + } + + /** + * Transforms a [[Vector]] into Hivemall features. + */ + def to_hivemall_features: UserDefinedFunction = udf(to_hivemall_features_func) + + /** + * Returns a new [[Vector]] with `1.0` (bias) appended to the input [[Vector]]. 
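+   *
+   * A short sketch, assuming a DataFrame `df` with an ML-vector column named `features`
+   * (both names are illustrative only):
+   * {{{
+   *   import org.apache.spark.sql.functions.col
+   *   import org.apache.spark.sql.hive.HivemallUtils._
+   *
+   *   // Append a 1.0 bias to each vector, then convert it into "index:value" features
+   *   df.select(to_hivemall_features(append_bias(col("features"))))
+   * }}}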
+ * @group ftvec + */ + def append_bias: UserDefinedFunction = udf(append_bias_func) + + /** + * Builds a [[Vector]]-based model from a table of Hivemall models + */ + def vectorized_model(df: DataFrame, dense: Boolean = false, dims: Int = maxDims) + : UserDefinedFunction = { + checkColumnType(df.schema, "feature", StringType) + checkColumnType(df.schema, "weight", DoubleType) + + import df.sqlContext.implicits._ + val intercept = df + .where($"feature" === "0") + .select($"weight") + .map { case Row(weight: Double) => weight} + .reduce(_ + _) + val weights = to_vector_func(dense, dims)( + df.select($"feature", $"weight") + .where($"feature" !== "0") + .map { case Row(label: String, feature: Double) => s"${label}:$feature"} + .collect.toSeq) + + udf((input: Vector) => BLAS.dot(input, weights) + intercept) + } +} diff --git a/spark/spark-2.3/src/main/scala/org/apache/spark/sql/hive/internal/HivemallOpsImpl.scala b/spark/spark-2.3/src/main/scala/org/apache/spark/sql/hive/internal/HivemallOpsImpl.scala new file mode 100644 index 000000000..fdd2cafe5 --- /dev/null +++ b/spark/spark-2.3/src/main/scala/org/apache/spark/sql/hive/internal/HivemallOpsImpl.scala @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.hive.internal + +import org.apache.spark.internal.Logging +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.plans.logical.{Generate, LogicalPlan} +import org.apache.spark.sql.hive._ +import org.apache.spark.sql.hive.HiveShim.HiveFunctionWrapper + +/** + * This is an implementation class for [[org.apache.spark.sql.hive.HivemallOps]]. + * This class mainly uses the internal Spark classes (e.g., `Generate` and `HiveGenericUDTF`) that + * have unstable interfaces (so, these interfaces may evolve in upcoming releases). + * Therefore, the objective of this class is to extract these unstable parts + * from [[org.apache.spark.sql.hive.HivemallOps]]. 
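+ *
+ * For instance, `HivemallOps.tile` simply delegates to `planHiveGenericUDF` as follows:
+ * {{{
+ *   planHiveGenericUDF(
+ *     "hivemall.geospatial.TileUDF",
+ *     "tile",
+ *     lat :: lon :: zoom :: Nil
+ *   )
+ * }}}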
+ */ +private[hive] object HivemallOpsImpl extends Logging { + + def planHiveUDF( + className: String, + funcName: String, + argumentExprs: Seq[Column]): Expression = { + HiveSimpleUDF( + name = funcName, + funcWrapper = new HiveFunctionWrapper(className), + children = argumentExprs.map(_.expr) + ) + } + + def planHiveGenericUDF( + className: String, + funcName: String, + argumentExprs: Seq[Column]): Expression = { + HiveGenericUDF( + name = funcName, + funcWrapper = new HiveFunctionWrapper(className), + children = argumentExprs.map(_.expr) + ) + } + + def planHiveGenericUDTF( + df: DataFrame, + className: String, + funcName: String, + argumentExprs: Seq[Column], + outputAttrNames: Seq[String]): LogicalPlan = { + Generate( + generator = HiveGenericUDTF( + name = funcName, + funcWrapper = new HiveFunctionWrapper(className), + children = argumentExprs.map(_.expr) + ), + unrequiredChildIndex = Seq.empty, + outer = false, + qualifier = None, + generatorOutput = outputAttrNames.map(UnresolvedAttribute(_)), + child = df.logicalPlan) + } +} diff --git a/spark/spark-2.3/src/main/scala/org/apache/spark/sql/hive/source/XGBoostFileFormat.scala b/spark/spark-2.3/src/main/scala/org/apache/spark/sql/hive/source/XGBoostFileFormat.scala new file mode 100644 index 000000000..65cdf2448 --- /dev/null +++ b/spark/spark-2.3/src/main/scala/org/apache/spark/sql/hive/source/XGBoostFileFormat.scala @@ -0,0 +1,163 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.spark.sql.hive.source + +import java.io.File +import java.io.IOException +import java.net.URI + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileStatus, FSDataInputStream, Path} +import org.apache.hadoop.io.IOUtils +import org.apache.hadoop.io.compress.GzipCodec +import org.apache.hadoop.mapreduce._ +import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat +import org.apache.hadoop.util.ReflectionUtils + +import org.apache.spark.sql.{Row, SparkSession} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.encoders.RowEncoder +import org.apache.spark.sql.catalyst.expressions.AttributeReference +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection +import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.sources._ +import org.apache.spark.sql.types._ +import org.apache.spark.util.SerializableConfiguration + +private[source] final class XGBoostOutputWriter( + path: String, + dataSchema: StructType, + context: TaskAttemptContext) + extends OutputWriter { + + private val hadoopConf = new SerializableConfiguration(new Configuration()) + + override def write(row: InternalRow): Unit = { + val fields = row.toSeq(dataSchema) + val model = fields(1).asInstanceOf[Array[Byte]] + val filePath = new Path(new URI(s"$path")) + val fs = filePath.getFileSystem(hadoopConf.value) + val outputFile = fs.create(filePath) + outputFile.write(model) + outputFile.close() + } + + override def close(): Unit = {} +} + +object XGBoostOutputWriter { + + /** Returns the compression codec extension to be used in a file name, e.g. ".gzip"). */ + def getCompressionExtension(context: TaskAttemptContext): String = { + if (FileOutputFormat.getCompressOutput(context)) { + val codecClass = FileOutputFormat.getOutputCompressorClass(context, classOf[GzipCodec]) + ReflectionUtils.newInstance(codecClass, context.getConfiguration).getDefaultExtension + } else { + "" + } + } +} + +final class XGBoostFileFormat extends FileFormat with DataSourceRegister { + + override def shortName(): String = "libxgboost" + + override def toString: String = "XGBoost" + + private def verifySchema(dataSchema: StructType): Unit = { + if ( + dataSchema.size != 2 || + !dataSchema(0).dataType.sameType(StringType) || + !dataSchema(1).dataType.sameType(BinaryType) + ) { + throw new IOException(s"Illegal schema for XGBoost data, schema=$dataSchema") + } + } + + override def inferSchema( + sparkSession: SparkSession, + options: Map[String, String], + files: Seq[FileStatus]): Option[StructType] = { + Some( + StructType( + StructField("model_id", StringType, nullable = false) :: + StructField("pred_model", BinaryType, nullable = false) :: Nil) + ) + } + + override def prepareWrite( + sparkSession: SparkSession, + job: Job, + options: Map[String, String], + dataSchema: StructType): OutputWriterFactory = { + new OutputWriterFactory { + override def newInstance( + path: String, + dataSchema: StructType, + context: TaskAttemptContext): OutputWriter = { + new XGBoostOutputWriter(path, dataSchema, context) + } + + override def getFileExtension(context: TaskAttemptContext): String = { + XGBoostOutputWriter.getCompressionExtension(context) + ".xgboost" + } + } + } + + override def buildReader( + sparkSession: SparkSession, + dataSchema: StructType, + partitionSchema: StructType, + requiredSchema: StructType, + filters: Seq[Filter], + options: Map[String, String], + hadoopConf: Configuration): (PartitionedFile) => 
Iterator[InternalRow] = { + verifySchema(dataSchema) + val broadcastedHadoopConf = + sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) + + (file: PartitionedFile) => { + val model = new Array[Byte](file.length.asInstanceOf[Int]) + val filePath = new Path(new URI(file.filePath)) + val fs = filePath.getFileSystem(broadcastedHadoopConf.value.value) + + var in: FSDataInputStream = null + try { + in = fs.open(filePath) + IOUtils.readFully(in, model, 0, model.length) + } finally { + IOUtils.closeStream(in) + } + + val converter = RowEncoder(dataSchema) + val fullOutput = dataSchema.map { f => + AttributeReference(f.name, f.dataType, f.nullable, f.metadata)() + } + val requiredOutput = fullOutput.filter { a => + requiredSchema.fieldNames.contains(a.name) + } + val requiredColumns = GenerateUnsafeProjection.generate(requiredOutput, fullOutput) + (requiredColumns( + converter.toRow(Row(new File(file.filePath).getName, model))) + :: Nil + ).toIterator + } + } +} diff --git a/spark/spark-2.3/src/main/scala/org/apache/spark/streaming/HivemallStreamingOps.scala b/spark/spark-2.3/src/main/scala/org/apache/spark/streaming/HivemallStreamingOps.scala new file mode 100644 index 000000000..a6bbb4b57 --- /dev/null +++ b/spark/spark-2.3/src/main/scala/org/apache/spark/streaming/HivemallStreamingOps.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.spark.streaming + +import scala.reflect.ClassTag + +import org.apache.spark.ml.feature.HivemallLabeledPoint +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataFrame, Row, SQLContext} +import org.apache.spark.streaming.dstream.DStream + +final class HivemallStreamingOps(ds: DStream[HivemallLabeledPoint]) { + + def predict[U: ClassTag](f: DataFrame => DataFrame)(implicit sqlContext: SQLContext) + : DStream[Row] = { + ds.transform[Row] { rdd: RDD[HivemallLabeledPoint] => + f(sqlContext.createDataFrame(rdd)).rdd + } + } +} + +object HivemallStreamingOps { + + /** + * Implicitly inject the [[HivemallStreamingOps]] into [[DStream]]. + */ + implicit def dataFrameToHivemallStreamingOps(ds: DStream[HivemallLabeledPoint]) + : HivemallStreamingOps = { + new HivemallStreamingOps(ds) + } +} diff --git a/spark/spark-2.3/src/test/resources/data/files/README.md b/spark/spark-2.3/src/test/resources/data/files/README.md new file mode 100644 index 000000000..238d4721f --- /dev/null +++ b/spark/spark-2.3/src/test/resources/data/files/README.md @@ -0,0 +1,22 @@ + + +The files in this dir exist for preventing exceptions in o.a.s.sql.hive.test.TESTHive. +We need to fix this issue in future. 
+ diff --git a/spark/spark-2.3/src/test/resources/data/files/complex.seq b/spark/spark-2.3/src/test/resources/data/files/complex.seq new file mode 100644 index 000000000..e69de29bb diff --git a/spark/spark-2.3/src/test/resources/data/files/episodes.avro b/spark/spark-2.3/src/test/resources/data/files/episodes.avro new file mode 100644 index 000000000..e69de29bb diff --git a/spark/spark-2.3/src/test/resources/data/files/json.txt b/spark/spark-2.3/src/test/resources/data/files/json.txt new file mode 100644 index 000000000..e69de29bb diff --git a/spark/spark-2.3/src/test/resources/data/files/kv1.txt b/spark/spark-2.3/src/test/resources/data/files/kv1.txt new file mode 100644 index 000000000..e69de29bb diff --git a/spark/spark-2.3/src/test/resources/data/files/kv3.txt b/spark/spark-2.3/src/test/resources/data/files/kv3.txt new file mode 100644 index 000000000..e69de29bb diff --git a/spark/spark-2.3/src/test/resources/log4j.properties b/spark/spark-2.3/src/test/resources/log4j.properties new file mode 100644 index 000000000..c6e4297e1 --- /dev/null +++ b/spark/spark-2.3/src/test/resources/log4j.properties @@ -0,0 +1,24 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Set everything to be logged to the console +log4j.rootCategory=FATAL, console +log4j.appender.console=org.apache.log4j.ConsoleAppender +log4j.appender.console.target=System.err +log4j.appender.console.layout=org.apache.log4j.PatternLayout +log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n + diff --git a/spark/spark-2.3/src/test/scala/hivemall/mix/server/MixServerSuite.scala b/spark/spark-2.3/src/test/scala/hivemall/mix/server/MixServerSuite.scala new file mode 100644 index 000000000..9bbd3f0f4 --- /dev/null +++ b/spark/spark-2.3/src/test/scala/hivemall/mix/server/MixServerSuite.scala @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package hivemall.mix.server + +import java.util.Random +import java.util.concurrent.{Executors, ExecutorService, TimeUnit} +import java.util.logging.Logger + +import hivemall.mix.MixMessage.MixEventName +import hivemall.mix.client.MixClient +import hivemall.mix.server.MixServer.ServerState +import hivemall.model.{DenseModel, PredictionModel, WeightValue} +import hivemall.utils.io.IOUtils +import hivemall.utils.lang.CommandLineUtils +import hivemall.utils.net.NetUtils +import org.scalatest.{BeforeAndAfter, FunSuite} + +class MixServerSuite extends FunSuite with BeforeAndAfter { + + private[this] var server: MixServer = _ + private[this] var executor : ExecutorService = _ + private[this] var port: Int = _ + + private[this] val rand = new Random(43) + private[this] val counter = Stream.from(0).iterator + + private[this] val eachTestTime = 100 + private[this] val logger = + Logger.getLogger(classOf[MixServerSuite].getName) + + before { + this.port = NetUtils.getAvailablePort + this.server = new MixServer( + CommandLineUtils.parseOptions( + Array("-port", s"${port}", "-sync_threshold", "3"), + MixServer.getOptions() + ) + ) + this.executor = Executors.newSingleThreadExecutor + this.executor.submit(server) + var retry = 0 + while (server.getState() != ServerState.RUNNING && retry < 50) { + Thread.sleep(1000L) + retry += 1 + } + assert(server.getState == ServerState.RUNNING) + } + + after { this.executor.shutdown() } + + private[this] def clientDriver( + groupId: String, model: PredictionModel, numMsg: Int = 1000000): Unit = { + var client: MixClient = null + try { + client = new MixClient(MixEventName.average, groupId, s"localhost:${port}", false, 2, model) + model.configureMix(client, false) + model.configureClock() + + for (_ <- 0 until numMsg) { + val feature = Integer.valueOf(rand.nextInt(model.size)) + model.set(feature, new WeightValue(1.0f)) + } + + while (true) { Thread.sleep(eachTestTime * 1000 + 100L) } + assert(model.getNumMixed > 0) + } finally { + IOUtils.closeQuietly(client) + } + } + + private[this] def fixedGroup: (String, () => String) = + ("fixed", () => "fixed") + private[this] def uniqueGroup: (String, () => String) = + ("unique", () => s"${counter.next}") + + Seq(65536).map { ndims => + Seq(4).map { nclient => + Seq(fixedGroup, uniqueGroup).map { id => + val testName = s"dense-dim:${ndims}-clinet:${nclient}-${id._1}" + ignore(testName) { + val clients = Executors.newCachedThreadPool() + val numClients = nclient + val models = (0 until numClients).map(i => new DenseModel(ndims, false)) + (0 until numClients).map { i => + clients.submit(new Runnable() { + override def run(): Unit = { + try { + clientDriver( + s"${testName}-${id._2}", + models(i) + ) + } catch { + case e: InterruptedException => + assert(false, e.getMessage) + } + } + }) + } + clients.awaitTermination(eachTestTime, TimeUnit.SECONDS) + clients.shutdown() + val nMixes = models.map(d => d.getNumMixed).reduce(_ + _) + logger.info(s"${testName} --> ${(nMixes + 0.0) / eachTestTime} mixes/s") + } + } + } + } +} diff --git a/spark/spark-2.3/src/test/scala/hivemall/tools/RegressionDatagenSuite.scala b/spark/spark-2.3/src/test/scala/hivemall/tools/RegressionDatagenSuite.scala new file mode 100644 index 000000000..c12727610 --- /dev/null +++ b/spark/spark-2.3/src/test/scala/hivemall/tools/RegressionDatagenSuite.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package hivemall.tools + +import org.scalatest.FunSuite + +import org.apache.spark.sql.hive.test.TestHive + +class RegressionDatagenSuite extends FunSuite { + + test("datagen") { + val df = RegressionDatagen.exec( + TestHive, min_examples = 10000, n_features = 100, n_dims = 65536, dense = false, cl = true) + assert(df.count() >= 10000) + } +} diff --git a/spark/spark-2.3/src/test/scala/org/apache/spark/ml/feature/HivemallLabeledPointSuite.scala b/spark/spark-2.3/src/test/scala/org/apache/spark/ml/feature/HivemallLabeledPointSuite.scala new file mode 100644 index 000000000..903dc0ae3 --- /dev/null +++ b/spark/spark-2.3/src/test/scala/org/apache/spark/ml/feature/HivemallLabeledPointSuite.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.SparkFunSuite + +class HivemallLabeledPointSuite extends SparkFunSuite { + + test("toString") { + val lp = HivemallLabeledPoint(1.0f, Seq("1:0.5", "3:0.3", "8:0.1")) + assert(lp.toString === "1.0,[1:0.5,3:0.3,8:0.1]") + } + + test("parse") { + val lp = HivemallLabeledPoint.parse("1.0,[1:0.5,3:0.3,8:0.1]") + assert(lp.label === 1.0) + assert(lp.features === Seq("1:0.5", "3:0.3", "8:0.1")) + } +} diff --git a/spark/spark-2.3/src/test/scala/org/apache/spark/sql/execution/benchmark/BenchmarkBaseAccessor.scala b/spark/spark-2.3/src/test/scala/org/apache/spark/sql/execution/benchmark/BenchmarkBaseAccessor.scala new file mode 100644 index 000000000..9e5e2048b --- /dev/null +++ b/spark/spark-2.3/src/test/scala/org/apache/spark/sql/execution/benchmark/BenchmarkBaseAccessor.scala @@ -0,0 +1,23 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.execution.benchmark + +// This trait makes `BenchmarkBase` accessible from `o.a.s.sql.hive.benchmark.MiscBenchmark` +private[sql] trait BenchmarkBaseAccessor extends BenchmarkBase diff --git a/spark/spark-2.3/src/test/scala/org/apache/spark/sql/hive/HiveUdfSuite.scala b/spark/spark-2.3/src/test/scala/org/apache/spark/sql/hive/HiveUdfSuite.scala new file mode 100644 index 000000000..f16eae0c8 --- /dev/null +++ b/spark/spark-2.3/src/test/scala/org/apache/spark/sql/hive/HiveUdfSuite.scala @@ -0,0 +1,161 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.hive + +import org.apache.spark.sql.Row +import org.apache.spark.sql.hive.HivemallUtils._ +import org.apache.spark.sql.hive.test.HivemallFeatureQueryTest +import org.apache.spark.sql.test.VectorQueryTest + +final class HiveUdfWithFeatureSuite extends HivemallFeatureQueryTest { + import hiveContext.implicits._ + import hiveContext._ + + test("hivemall_version") { + sql(s""" + | CREATE TEMPORARY FUNCTION hivemall_version + | AS '${classOf[hivemall.HivemallVersionUDF].getName}' + """.stripMargin) + + checkAnswer( + sql(s"SELECT DISTINCT hivemall_version()"), + Row("0.5.1-incubating-SNAPSHOT") + ) + + // sql("DROP TEMPORARY FUNCTION IF EXISTS hivemall_version") + // reset() + } + + test("train_logregr") { + TinyTrainData.createOrReplaceTempView("TinyTrainData") + sql(s""" + | CREATE TEMPORARY FUNCTION train_logregr + | AS '${classOf[hivemall.regression.LogressUDTF].getName}' + """.stripMargin) + sql(s""" + | CREATE TEMPORARY FUNCTION add_bias + | AS '${classOf[hivemall.ftvec.AddBiasUDFWrapper].getName}' + """.stripMargin) + + val model = sql( + s""" + | SELECT feature, AVG(weight) AS weight + | FROM ( + | SELECT train_logregr(add_bias(features), label) AS (feature, weight) + | FROM TinyTrainData + | ) t + | GROUP BY feature + """.stripMargin) + + checkAnswer( + model.select($"feature"), + Seq(Row("0"), Row("1"), Row("2")) + ) + + // TODO: Why 'train_logregr' is not registered in HiveMetaStore? 
+ // ERROR RetryingHMSHandler: MetaException(message:NoSuchObjectException + // (message:Function default.train_logregr does not exist)) + // + // hiveContext.sql("DROP TEMPORARY FUNCTION IF EXISTS train_logregr") + // hiveContext.reset() + } + + test("each_top_k") { + val testDf = Seq( + ("a", "1", 0.5, Array(0, 1, 2)), + ("b", "5", 0.1, Array(3)), + ("a", "3", 0.8, Array(2, 5)), + ("c", "6", 0.3, Array(1, 3)), + ("b", "4", 0.3, Array(2)), + ("a", "2", 0.6, Array(1)) + ).toDF("key", "value", "score", "data") + + import testDf.sqlContext.implicits._ + testDf.repartition($"key").sortWithinPartitions($"key").createOrReplaceTempView("TestData") + sql(s""" + | CREATE TEMPORARY FUNCTION each_top_k + | AS '${classOf[hivemall.tools.EachTopKUDTF].getName}' + """.stripMargin) + + // Compute top-1 rows for each group + checkAnswer( + sql("SELECT each_top_k(1, key, score, key, value) FROM TestData"), + Row(1, 0.8, "a", "3") :: + Row(1, 0.3, "b", "4") :: + Row(1, 0.3, "c", "6") :: + Nil + ) + + // Compute reverse top-1 rows for each group + checkAnswer( + sql("SELECT each_top_k(-1, key, score, key, value) FROM TestData"), + Row(1, 0.5, "a", "1") :: + Row(1, 0.1, "b", "5") :: + Row(1, 0.3, "c", "6") :: + Nil + ) + } +} + +final class HiveUdfWithVectorSuite extends VectorQueryTest { + import hiveContext._ + + test("to_hivemall_features") { + mllibTrainDf.createOrReplaceTempView("mllibTrainDf") + hiveContext.udf.register("to_hivemall_features", to_hivemall_features_func) + checkAnswer( + sql( + s""" + | SELECT to_hivemall_features(features) + | FROM mllibTrainDf + """.stripMargin), + Seq( + Row(Seq("0:1.0", "2:2.0", "4:3.0")), + Row(Seq("0:1.0", "3:1.5", "4:2.1", "6:1.2")), + Row(Seq("0:1.1", "3:1.0", "4:2.3", "6:1.0")), + Row(Seq("1:4.0", "3:5.0", "5:6.0")) + ) + ) + } + + test("append_bias") { + mllibTrainDf.createOrReplaceTempView("mllibTrainDf") + hiveContext.udf.register("append_bias", append_bias_func) + hiveContext.udf.register("to_hivemall_features", to_hivemall_features_func) + checkAnswer( + sql( + s""" + | SELECT to_hivemall_features(append_bias(features)) + | FROM mllibTrainDF + """.stripMargin), + Seq( + Row(Seq("0:1.0", "2:2.0", "4:3.0", "7:1.0")), + Row(Seq("0:1.0", "3:1.5", "4:2.1", "6:1.2", "7:1.0")), + Row(Seq("0:1.1", "3:1.0", "4:2.3", "6:1.0", "7:1.0")), + Row(Seq("1:4.0", "3:5.0", "5:6.0", "7:1.0")) + ) + ) + } + + ignore("explode_vector") { + // TODO: Spark-2.0 does not support use-defined generator function in + // `org.apache.spark.sql.UDFRegistration`. + } +} diff --git a/spark/spark-2.3/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala b/spark/spark-2.3/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala new file mode 100644 index 000000000..f2b7b6ef2 --- /dev/null +++ b/spark/spark-2.3/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala @@ -0,0 +1,1393 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.hive + +import org.apache.spark.sql.{AnalysisException, Row} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.hive.HivemallGroupedDataset._ +import org.apache.spark.sql.hive.HivemallOps._ +import org.apache.spark.sql.hive.HivemallUtils._ +import org.apache.spark.sql.hive.test.HivemallFeatureQueryTest +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.VectorQueryTest +import org.apache.spark.sql.types._ +import org.apache.spark.test.TestFPWrapper._ +import org.apache.spark.test.TestUtils + + +class HivemallOpsWithFeatureSuite extends HivemallFeatureQueryTest { + + test("anomaly") { + import hiveContext.implicits._ + val df = spark.range(1000).selectExpr("id AS time", "rand() AS x") + // TODO: Test results more strictly + assert(df.sort($"time".asc).select(changefinder($"x")).count === 1000) + assert(df.sort($"time".asc).select(sst($"x", lit("-th 0.005"))).count === 1000) + } + + test("knn.similarity") { + import hiveContext.implicits._ + + val df1 = DummyInputData.select( + cosine_similarity(typedLit(Seq(1, 2, 3, 4)), typedLit(Seq(3, 4, 5, 6)))) + val rows1 = df1.collect + assert(rows1.length == 1) + assert(rows1(0).getFloat(0) ~== 0.500f) + + val df2 = DummyInputData.select(jaccard_similarity(lit(5), lit(6))) + val rows2 = df2.collect + assert(rows2.length == 1) + assert(rows2(0).getFloat(0) ~== 0.96875f) + + val df3 = DummyInputData.select( + angular_similarity(typedLit(Seq(1, 2, 3)), typedLit(Seq(4, 5, 6)))) + val rows3 = df3.collect + assert(rows3.length == 1) + assert(rows3(0).getFloat(0) ~== 0.500f) + + val df4 = DummyInputData.select( + euclid_similarity(typedLit(Seq(5, 3, 1)), typedLit(Seq(2, 8, 3)))) + val rows4 = df4.collect + assert(rows4.length == 1) + assert(rows4(0).getFloat(0) ~== 0.33333334f) + + val df5 = DummyInputData.select(distance2similarity(lit(1.0))) + val rows5 = df5.collect + assert(rows5.length == 1) + assert(rows5(0).getFloat(0) ~== 0.5f) + + val df6 = Seq((Seq("1:0.3", "4:0.1"), Map(0 -> 0.5))).toDF("a", "b") + // TODO: Currently, just check if no exception thrown + assert(df6.dimsum_mapper(df6("a"), df6("b")).collect.isEmpty) + } + + test("knn.distance") { + val df1 = DummyInputData.select(hamming_distance(lit(1), lit(3))) + checkAnswer(df1, Row(1)) + + val df2 = DummyInputData.select(popcnt(lit(1))) + checkAnswer(df2, Row(1)) + + val rows3 = DummyInputData.select(kld(lit(0.1), lit(0.5), lit(0.2), lit(0.5))).collect + assert(rows3.length === 1) + assert(rows3(0).getDouble(0) ~== 0.01) + + val rows4 = DummyInputData.select( + euclid_distance(typedLit(Seq("0.1", "0.5")), typedLit(Seq("0.2", "0.5")))).collect + assert(rows4.length === 1) + assert(rows4(0).getFloat(0) ~== 1.4142135f) + + val rows5 = DummyInputData.select( + cosine_distance(typedLit(Seq("0.8", "0.3")), typedLit(Seq("0.4", "0.6")))).collect + assert(rows5.length === 1) + assert(rows5(0).getFloat(0) ~== 1.0f) + + val rows6 = DummyInputData.select( + angular_distance(typedLit(Seq("0.1", "0.1")), typedLit(Seq("0.3", "0.8")))).collect + assert(rows6.length === 1) + assert(rows6(0).getFloat(0) ~== 
0.50f) + + val rows7 = DummyInputData.select( + manhattan_distance(typedLit(Seq("0.7", "0.8")), typedLit(Seq("0.5", "0.6")))).collect + assert(rows7.length === 1) + assert(rows7(0).getFloat(0) ~== 4.0f) + + val rows8 = DummyInputData.select( + minkowski_distance(typedLit(Seq("0.1", "0.2")), typedLit(Seq("0.2", "0.2")), typedLit(1.0)) + ).collect + assert(rows8.length === 1) + assert(rows8(0).getFloat(0) ~== 2.0f) + + val rows9 = DummyInputData.select( + jaccard_distance(typedLit(Seq("0.3", "0.8")), typedLit(Seq("0.1", "0.2")))).collect + assert(rows9.length === 1) + assert(rows9(0).getFloat(0) ~== 1.0f) + } + + test("knn.lsh") { + import hiveContext.implicits._ + checkAnswer( + IntList2Data.minhash(lit(1), $"target"), + Row(1016022700, 1) :: + Row(1264890450, 1) :: + Row(1304330069, 1) :: + Row(1321870696, 1) :: + Row(1492709716, 1) :: + Row(1511363108, 1) :: + Row(1601347428, 1) :: + Row(1974434012, 1) :: + Row(2022223284, 1) :: + Row(326269457, 1) :: + Row(50559334, 1) :: + Row(716040854, 1) :: + Row(759249519, 1) :: + Row(809187771, 1) :: + Row(900899651, 1) :: + Nil + ) + checkAnswer( + DummyInputData.select(bbit_minhash(typedLit(Seq("1:0.1", "2:0.5")), lit(false))), + Row("31175986876675838064867796245644543067") + ) + checkAnswer( + DummyInputData.select(minhashes(typedLit(Seq("1:0.1", "2:0.5")), lit(false))), + Row(Seq(1571683640, 987207869, 370931990, 988455638, 846963275)) + ) + } + + test("ftvec - add_bias") { + import hiveContext.implicits._ + checkAnswer(TinyTrainData.select(add_bias($"features")), + Row(Seq("1:0.8", "2:0.2", "0:1.0")) :: + Row(Seq("2:0.7", "0:1.0")) :: + Row(Seq("1:0.9", "0:1.0")) :: + Nil + ) + } + + test("ftvec - extract_feature") { + val df = DummyInputData.select(extract_feature(lit("1:0.8"))) + checkAnswer(df, Row("1")) + } + + test("ftvec - extract_weight") { + val rows = DummyInputData.select(extract_weight(lit("3:0.1"))).collect + assert(rows.length === 1) + assert(rows(0).getDouble(0) ~== 0.1) + } + + test("ftvec - explode_array") { + import hiveContext.implicits._ + val df = TinyTrainData.explode_array($"features").select($"feature") + checkAnswer(df, Row("1:0.8") :: Row("2:0.2") :: Row("2:0.7") :: Row("1:0.9") :: Nil) + } + + test("ftvec - add_feature_index") { + import hiveContext.implicits._ + val doubleListData = Seq(Array(0.8, 0.5), Array(0.3, 0.1), Array(0.2)).toDF("data") + checkAnswer( + doubleListData.select(add_feature_index($"data")), + Row(Seq("1:0.8", "2:0.5")) :: + Row(Seq("1:0.3", "2:0.1")) :: + Row(Seq("1:0.2")) :: + Nil + ) + } + + test("ftvec - sort_by_feature") { + // import hiveContext.implicits._ + val intFloatMapData = { + // TODO: Use `toDF` + val rowRdd = hiveContext.sparkContext.parallelize( + Row(Map(1 -> 0.3f, 2 -> 0.1f, 3 -> 0.5f)) :: + Row(Map(2 -> 0.4f, 1 -> 0.2f)) :: + Row(Map(2 -> 0.4f, 3 -> 0.2f, 1 -> 0.1f, 4 -> 0.6f)) :: + Nil + ) + hiveContext.createDataFrame( + rowRdd, + StructType( + StructField("data", MapType(IntegerType, FloatType), true) :: + Nil) + ) + } + val sortedKeys = intFloatMapData.select(sort_by_feature(intFloatMapData.col("data"))) + .collect.map { + case Row(m: Map[Int, Float]) => m.keysIterator.toSeq + } + assert(sortedKeys.toSet === Set(Seq(1, 2, 3), Seq(1, 2), Seq(1, 2, 3, 4))) + } + + test("ftvec.hash") { + checkAnswer(DummyInputData.select(mhash(lit("test"))), Row(4948445)) + checkAnswer(DummyInputData.select(HivemallOps.sha1(lit("test"))), Row(12184508)) + checkAnswer(DummyInputData.select(feature_hashing(typedLit(Seq("1:0.1", "3:0.5")))), + Row(Seq("11293631:0.1", "4331412:0.5"))) + 
checkAnswer(DummyInputData.select(array_hash_values(typedLit(Seq("aaa", "bbb")))), + Row(Seq(4063537, 8459207))) + checkAnswer(DummyInputData.select( + prefixed_hash_values(typedLit(Seq("ccc", "ddd")), lit("prefix"))), + Row(Seq("prefix7873825", "prefix8965544"))) + } + + test("ftvec.parting") { + checkAnswer(DummyInputData.select(polynomial_features(typedLit(Seq("2:0.4", "6:0.1")), lit(2))), + Row(Seq("2:0.4", "2^2:0.16000001", "2^6:0.040000003", "6:0.1", "6^6:0.010000001"))) + checkAnswer(DummyInputData.select(powered_features(typedLit(Seq("4:0.8", "5:0.2")), lit(2))), + Row(Seq("4:0.8", "4^2:0.64000005", "5:0.2", "5^2:0.040000003"))) + } + + test("ftvec.scaling") { + val rows1 = TinyTrainData.select(rescale(lit(2.0f), lit(1.0), lit(5.0))).collect + assert(rows1.length === 3) + assert(rows1(0).getFloat(0) ~== 0.25f) + assert(rows1(1).getFloat(0) ~== 0.25f) + assert(rows1(2).getFloat(0) ~== 0.25f) + val rows2 = TinyTrainData.select(zscore(lit(1.0f), lit(0.5), lit(0.5))).collect + assert(rows2.length === 3) + assert(rows2(0).getFloat(0) ~== 1.0f) + assert(rows2(1).getFloat(0) ~== 1.0f) + assert(rows2(2).getFloat(0) ~== 1.0f) + val df3 = TinyTrainData.select(l2_normalize(TinyTrainData.col("features"))) + checkAnswer( + df3, + Row(Seq("1:0.9701425", "2:0.24253562")) :: + Row(Seq("2:1.0")) :: + Row(Seq("1:1.0")) :: + Nil) + } + + test("ftvec.selection - chi2") { + import hiveContext.implicits._ + + // See also hivemall.ftvec.selection.ChiSquareUDFTest + val df = Seq( + Seq( + Seq(250.29999999999998, 170.90000000000003, 73.2, 12.199999999999996), + Seq(296.8, 138.50000000000003, 212.99999999999997, 66.3), + Seq(329.3999999999999, 148.7, 277.59999999999997, 101.29999999999998) + ) -> Seq( + Seq(292.1666753739119, 152.70000455081467, 187.93333893418327, 59.93333511948589), + Seq(292.1666753739119, 152.70000455081467, 187.93333893418327, 59.93333511948589), + Seq(292.1666753739119, 152.70000455081467, 187.93333893418327, 59.93333511948589))) + .toDF("arg0", "arg1") + + val rows = df.select(chi2(df("arg0"), df("arg1"))).collect + assert(rows.length == 1) + val chi2Val = rows.head.getAs[Row](0).getAs[Seq[Double]](0) + val pVal = rows.head.getAs[Row](0).getAs[Seq[Double]](1) + + (chi2Val, Seq(10.81782088, 3.59449902, 116.16984746, 67.24482759)) + .zipped + .foreach((actual, expected) => assert(actual ~== expected)) + + (pVal, Seq(4.47651499e-03, 1.65754167e-01, 5.94344354e-26, 2.50017968e-15)) + .zipped + .foreach((actual, expected) => assert(actual ~== expected)) + } + + test("ftvec.conv - quantify") { + import hiveContext.implicits._ + val testDf = Seq((1, "aaa", true), (2, "bbb", false), (3, "aaa", false)).toDF + // This test is done in a single partition because `HivemallOps#quantify` assigns identifiers + // for non-numerical values in each partition. 
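+ // For illustration (an inference from the expected rows checked below, not a documented
+ // contract): identifiers appear to be assigned in order of first appearance within the
+ // partition, i.e. "aaa" -> 0 and "bbb" -> 1 for the string column, and true -> 0 and
+ // false -> 1 for the boolean column.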
+ checkAnswer( + testDf.coalesce(1).quantify(lit(true) +: testDf.cols: _*), + Row(1, 0, 0) :: Row(2, 1, 1) :: Row(3, 0, 1) :: Nil) + } + + test("ftvec.amplify") { + import hiveContext.implicits._ + assert(TinyTrainData.amplify(lit(3), $"label", $"features").count() == 9) + assert(TinyTrainData.part_amplify(lit(3)).count() == 9) + // TODO: The test below failed because: + // java.lang.RuntimeException: Unsupported literal type class scala.Tuple3 + // (-buf 128,label,features) + // + // assert(TinyTrainData.rand_amplify(lit(3), lit("-buf 8", $"label", $"features")).count() == 9) + } + + test("ftvec.conv") { + import hiveContext.implicits._ + + checkAnswer( + DummyInputData.select(to_dense_features(typedLit(Seq("0:0.1", "1:0.3")), lit(1))), + Row(Array(0.1f, 0.3f)) + ) + checkAnswer( + DummyInputData.select(to_sparse_features(typedLit(Seq(0.1f, 0.2f, 0.3f)))), + Row(Seq("0:0.1", "1:0.2", "2:0.3")) + ) + checkAnswer( + DummyInputData.select(feature_binning(typedLit(Seq("1")), typedLit(Map("1" -> Seq(0, 3))))), + Row(Seq("1")) + ) + } + + test("ftvec.trans") { + import hiveContext.implicits._ + + checkAnswer( + DummyInputData.select(vectorize_features(typedLit(Seq("a", "b")), lit(0.1f), lit(0.2f))), + Row(Seq("a:0.1", "b:0.2")) + ) + checkAnswer( + DummyInputData.select(categorical_features(typedLit(Seq("a", "b")), lit("c11"), lit("c12"))), + Row(Seq("a#c11", "b#c12")) + ) + checkAnswer( + DummyInputData.select(indexed_features(lit(0.1), lit(0.2), lit(0.3))), + Row(Seq("1:0.1", "2:0.2", "3:0.3")) + ) + checkAnswer( + DummyInputData.select(quantitative_features(typedLit(Seq("a", "b")), lit(0.1), lit(0.2))), + Row(Seq("a:0.1", "b:0.2")) + ) + checkAnswer( + DummyInputData.select(ffm_features(typedLit(Seq("1", "2")), lit(0.5), lit(0.2))), + Row(Seq("190:140405:1", "111:1058718:1")) + ) + checkAnswer( + DummyInputData.select(add_field_indices(typedLit(Seq("0.5", "0.1")))), + Row(Seq("1:0.5", "2:0.1")) + ) + + val df1 = Seq((1, -3, 1), (2, -2, 1)).toDF("a", "b", "c") + checkAnswer( + df1.binarize_label($"a", $"b", $"c"), + Row(1, 1) :: Row(1, 1) :: Row(1, 1) :: Nil + ) + val df2 = Seq(("xxx", "yyy", 0), ("zzz", "yyy", 1)).toDF("a", "b", "c").coalesce(1) + checkAnswer( + df2.quantified_features(lit(true), df2("a"), df2("b"), df2("c")), + Row(Seq(0.0, 0.0, 0.0)) :: Row(Seq(1.0, 0.0, 1.0)) :: Nil + ) + } + + test("ftvec.ranking") { + import hiveContext.implicits._ + + val df1 = Seq((1, 0 :: 3 :: 4 :: Nil), (2, 8 :: 9 :: Nil)).toDF("a", "b").coalesce(1) + checkAnswer( + df1.bpr_sampling($"a", $"b"), + Row(1, 0, 7) :: + Row(1, 3, 6) :: + Row(2, 8, 0) :: + Row(2, 8, 4) :: + Row(2, 9, 7) :: + Nil + ) + val df2 = Seq(1 :: 8 :: 9 :: Nil, 0 :: 3 :: Nil).toDF("a").coalesce(1) + checkAnswer( + df2.item_pairs_sampling($"a", lit(3)), + Row(0, 1) :: + Row(1, 0) :: + Row(3, 2) :: + Nil + ) + val df3 = Seq(3 :: 5 :: Nil, 0 :: Nil).toDF("a").coalesce(1) + checkAnswer( + df3.populate_not_in($"a", lit(1)), + Row(0) :: + Row(1) :: + Row(1) :: + Nil + ) + } + + test("tools") { + // checkAnswer( + // DummyInputData.select(convert_label(lit(5))), + // Nil + // ) + checkAnswer( + DummyInputData.select(x_rank(lit("abc"))), + Row(1) + ) + } + + test("tools.array") { + checkAnswer( + DummyInputData.select(float_array(lit(3))), + Row(Seq()) + ) + checkAnswer( + DummyInputData.select(array_remove(typedLit(Seq(1, 2, 3)), lit(2))), + Row(Seq(1, 3)) + ) + checkAnswer( + DummyInputData.select(sort_and_uniq_array(typedLit(Seq(2, 1, 3, 1)))), + Row(Seq(1, 2, 3)) + ) + checkAnswer( + 
DummyInputData.select(subarray_endwith(typedLit(Seq(1, 2, 3, 4, 5)), lit(4))), + Row(Seq(1, 2, 3, 4)) + ) + checkAnswer( + DummyInputData.select( + array_concat(typedLit(Seq(1, 2)), typedLit(Seq(3)), typedLit(Seq(4, 5)))), + Row(Seq(1, 2, 3, 4, 5)) + ) + checkAnswer( + DummyInputData.select(subarray(typedLit(Seq(1, 2, 3, 4, 5)), lit(2), lit(4))), + Row(Seq(3, 4)) + ) + checkAnswer( + DummyInputData.select(to_string_array(typedLit(Seq(1, 2, 3, 4, 5)))), + Row(Seq("1", "2", "3", "4", "5")) + ) + checkAnswer( + DummyInputData.select(array_intersect(typedLit(Seq(1, 2, 3)), typedLit(Seq(2, 3, 4)))), + Row(Seq(2, 3)) + ) + } + + test("tools.array - select_k_best") { + import hiveContext.implicits._ + + val data = Seq(Seq(0, 1, 3), Seq(2, 4, 1), Seq(5, 4, 9)) + val df = data.map(d => (d, Seq(3, 1, 2))).toDF("features", "importance_list") + val k = 2 + + checkAnswer( + df.select(select_k_best(df("features"), df("importance_list"), lit(k))), + Row(Seq(0.0, 3.0)) :: Row(Seq(2.0, 1.0)) :: Row(Seq(5.0, 9.0)) :: Nil + ) + } + + test("tools.bits") { + checkAnswer( + DummyInputData.select(to_bits(typedLit(Seq(1, 3, 9)))), + Row(Seq(522L)) + ) + checkAnswer( + DummyInputData.select(unbits(typedLit(Seq(1L, 3L)))), + Row(Seq(0L, 64L, 65L)) + ) + checkAnswer( + DummyInputData.select(bits_or(typedLit(Seq(1L, 3L)), typedLit(Seq(8L, 23L)))), + Row(Seq(9L, 23L)) + ) + } + + test("tools.compress") { + checkAnswer( + DummyInputData.select(inflate(deflate(lit("input text")))), + Row("input text") + ) + } + + test("tools.map") { + val rows = DummyInputData.select( + map_get_sum(typedLit(Map(1 -> 0.2f, 2 -> 0.5f, 4 -> 0.8f)), typedLit(Seq(1, 4))) + ).collect + assert(rows.length === 1) + assert(rows(0).getDouble(0) ~== 1.0f) + + checkAnswer( + DummyInputData.select(map_tail_n(typedLit(Map(1 -> 2, 2 -> 5)), lit(1))), + Row(Map(2 -> 5)) + ) + } + + test("tools.text") { + checkAnswer( + DummyInputData.select(tokenize(lit("This is a pen"))), + Row("This" :: "is" :: "a" :: "pen" :: Nil) + ) + checkAnswer( + DummyInputData.select(is_stopword(lit("because"))), + Row(true) + ) + checkAnswer( + DummyInputData.select(singularize(lit("between"))), + Row("between") + ) + checkAnswer( + DummyInputData.select(split_words(lit("Hello, world"))), + Row("Hello," :: "world" :: Nil) + ) + checkAnswer( + DummyInputData.select(normalize_unicode(lit("abcdefg"))), + Row("abcdefg") + ) + checkAnswer( + DummyInputData.select(base91(typedLit("input text".getBytes))), + Row("xojg[@TX;R..B") + ) + checkAnswer( + DummyInputData.select(unbase91(lit("XXXX"))), + Row(68 :: -120 :: 8 :: Nil) + ) + checkAnswer( + DummyInputData.select(word_ngrams(typedLit("abcd" :: "efg" :: "hij" :: Nil), lit(2), lit(2))), + Row("abcd efg" :: "efg hij" :: Nil) + ) + } + + test("tools - generated_series") { + checkAnswer( + DummyInputData.generate_series(lit(0), lit(3)), + Row(0) :: Row(1) :: Row(2) :: Row(3) :: Nil + ) + } + + test("geospatial") { + val rows1 = DummyInputData.select(tilex2lon(lit(1), lit(6))).collect + assert(rows1.length === 1) + assert(rows1(0).getDouble(0) ~== -174.375) + + val rows2 = DummyInputData.select(tiley2lat(lit(1), lit(3))).collect + assert(rows2.length === 1) + assert(rows2(0).getDouble(0) ~== 79.17133464081945) + + val rows3 = DummyInputData.select( + haversine_distance(lit(0.3), lit(0.1), lit(0.4), lit(0.1))).collect + assert(rows3.length === 1) + assert(rows3(0).getDouble(0) ~== 11.119492664455878) + + checkAnswer( + DummyInputData.select(tile(lit(0.1), lit(0.8), lit(3))), + Row(28) + ) + checkAnswer( + 
DummyInputData.select(map_url(lit(0.1), lit(0.8), lit(3))), + Row("http://tile.openstreetmap.org/3/4/3.png") + ) + checkAnswer( + DummyInputData.select(lat2tiley(lit(0.3), lit(3))), + Row(3) + ) + checkAnswer( + DummyInputData.select(lon2tilex(lit(0.4), lit(2))), + Row(2) + ) + } + + test("misc - hivemall_version") { + checkAnswer(DummyInputData.select(hivemall_version()), Row("0.5.1-incubating-SNAPSHOT")) + } + + test("misc - rowid") { + assert(DummyInputData.select(rowid()).distinct.count == DummyInputData.count) + } + + test("misc - each_top_k") { + import hiveContext.implicits._ + val inputDf = Seq( + ("a", "1", 0.5, 0.1, Array(0, 1, 2)), + ("b", "5", 0.1, 0.2, Array(3)), + ("a", "3", 0.8, 0.8, Array(2, 5)), + ("c", "6", 0.3, 0.3, Array(1, 3)), + ("b", "4", 0.3, 0.4, Array(2)), + ("a", "2", 0.6, 0.5, Array(1)) + ).toDF("key", "value", "x", "y", "data") + + // Compute top-1 rows for each group + val distance = sqrt(inputDf("x") * inputDf("x") + inputDf("y") * inputDf("y")).as("score") + val top1Df = inputDf.each_top_k(lit(1), distance, $"key".as("group")) + assert(top1Df.schema.toSet === Set( + StructField("rank", IntegerType, nullable = true), + StructField("score", DoubleType, nullable = true), + StructField("key", StringType, nullable = true), + StructField("value", StringType, nullable = true), + StructField("x", DoubleType, nullable = true), + StructField("y", DoubleType, nullable = true), + StructField("data", ArrayType(IntegerType, containsNull = false), nullable = true) + )) + checkAnswer( + top1Df.select($"rank", $"key", $"value", $"data"), + Row(1, "a", "3", Array(2, 5)) :: + Row(1, "b", "4", Array(2)) :: + Row(1, "c", "6", Array(1, 3)) :: + Nil + ) + + // Compute reverse top-1 rows for each group + val bottom1Df = inputDf.each_top_k(lit(-1), distance, $"key".as("group")) + checkAnswer( + bottom1Df.select($"rank", $"key", $"value", $"data"), + Row(1, "a", "1", Array(0, 1, 2)) :: + Row(1, "b", "5", Array(3)) :: + Row(1, "c", "6", Array(1, 3)) :: + Nil + ) + + // Check if some exceptions thrown in case of some conditions + assert(intercept[AnalysisException] { inputDf.each_top_k(lit(0.1), $"score", $"key") } + .getMessage contains "`k` must be integer, however") + assert(intercept[AnalysisException] { inputDf.each_top_k(lit(1), $"data", $"key") } + .getMessage contains "must have a comparable type") + } + + test("misc - join_top_k") { + Seq("true", "false").map { flag => + withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> flag) { + import hiveContext.implicits._ + val inputDf = Seq( + ("user1", 1, 0.3, 0.5), + ("user2", 2, 0.1, 0.1), + ("user3", 3, 0.8, 0.0), + ("user4", 1, 0.9, 0.9), + ("user5", 3, 0.7, 0.2), + ("user6", 1, 0.5, 0.4), + ("user7", 2, 0.6, 0.8) + ).toDF("userId", "group", "x", "y") + + val masterDf = Seq( + (1, "pos1-1", 0.5, 0.1), + (1, "pos1-2", 0.0, 0.0), + (1, "pos1-3", 0.3, 0.3), + (2, "pos2-3", 0.1, 0.3), + (2, "pos2-3", 0.8, 0.8), + (3, "pos3-1", 0.1, 0.7), + (3, "pos3-1", 0.7, 0.1), + (3, "pos3-1", 0.9, 0.0), + (3, "pos3-1", 0.1, 0.3) + ).toDF("group", "position", "x", "y") + + // Compute top-1 rows for each group + val distance = sqrt( + pow(inputDf("x") - masterDf("x"), lit(2.0)) + + pow(inputDf("y") - masterDf("y"), lit(2.0)) + ).as("score") + val top1Df = inputDf.top_k_join( + lit(1), masterDf, inputDf("group") === masterDf("group"), distance) + assert(top1Df.schema.toSet === Set( + StructField("rank", IntegerType, nullable = true), + StructField("score", DoubleType, nullable = true), + StructField("group", IntegerType, nullable = false), + 
StructField("userId", StringType, nullable = true), + StructField("position", StringType, nullable = true), + StructField("x", DoubleType, nullable = false), + StructField("y", DoubleType, nullable = false) + )) + checkAnswer( + top1Df.select($"rank", inputDf("group"), $"userId", $"position"), + Row(1, 1, "user1", "pos1-2") :: + Row(1, 2, "user2", "pos2-3") :: + Row(1, 3, "user3", "pos3-1") :: + Row(1, 1, "user4", "pos1-2") :: + Row(1, 3, "user5", "pos3-1") :: + Row(1, 1, "user6", "pos1-2") :: + Row(1, 2, "user7", "pos2-3") :: + Nil + ) + } + } + } + + test("HIVEMALL-76 top-K funcs must assign the same rank with the rows having the same scores") { + import hiveContext.implicits._ + val inputDf = Seq( + ("a", "1", 0.1), + ("b", "5", 0.1), + ("a", "3", 0.1), + ("b", "4", 0.1), + ("a", "2", 0.0) + ).toDF("key", "value", "x") + + // Compute top-2 rows for each group + val top2Df = inputDf.each_top_k(lit(2), $"x".as("score"), $"key".as("group")) + checkAnswer( + top2Df.select($"rank", $"score", $"key", $"value"), + Row(1, 0.1, "a", "3") :: + Row(1, 0.1, "a", "1") :: + Row(1, 0.1, "b", "4") :: + Row(1, 0.1, "b", "5") :: + Nil + ) + Seq("true", "false").map { flag => + withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> flag) { + val inputDf = Seq( + ("user1", 1, 0.3, 0.5), + ("user2", 2, 0.1, 0.1) + ).toDF("userId", "group", "x", "y") + + val masterDf = Seq( + (1, "pos1-1", 0.5, 0.1), + (1, "pos1-2", 0.5, 0.1), + (1, "pos1-3", 0.3, 0.4), + (2, "pos2-1", 0.8, 0.2), + (2, "pos2-2", 0.8, 0.2) + ).toDF("group", "position", "x", "y") + + // Compute top-2 rows for each group + val distance = sqrt( + pow(inputDf("x") - masterDf("x"), lit(2.0)) + + pow(inputDf("y") - masterDf("y"), lit(2.0)) + ).as("score") + val top2Df = inputDf.top_k_join( + lit(2), masterDf, inputDf("group") === masterDf("group"), distance) + checkAnswer( + top2Df.select($"rank", inputDf("group"), $"userId", $"position"), + Row(1, 1, "user1", "pos1-1") :: + Row(1, 1, "user1", "pos1-2") :: + Row(1, 2, "user2", "pos2-1") :: + Row(1, 2, "user2", "pos2-2") :: + Nil + ) + } + } + } + + test("misc - flatten") { + import hiveContext.implicits._ + val df = Seq((0, (1, "a", (3.0, "b")), (5, 0.9, "c", "d"), 9)).toDF() + assert(df.flatten().schema === StructType( + StructField("_1", IntegerType, nullable = false) :: + StructField("_2$_1", IntegerType, nullable = true) :: + StructField("_2$_2", StringType, nullable = true) :: + StructField("_2$_3$_1", DoubleType, nullable = true) :: + StructField("_2$_3$_2", StringType, nullable = true) :: + StructField("_3$_1", IntegerType, nullable = true) :: + StructField("_3$_2", DoubleType, nullable = true) :: + StructField("_3$_3", StringType, nullable = true) :: + StructField("_3$_4", StringType, nullable = true) :: + StructField("_4", IntegerType, nullable = false) :: + Nil + )) + checkAnswer(df.flatten("$").select("_2$_1"), Row(1)) + checkAnswer(df.flatten("_").select("_2__1"), Row(1)) + checkAnswer(df.flatten(".").select("`_2._1`"), Row(1)) + + val errMsg1 = intercept[IllegalArgumentException] { df.flatten("\t") } + assert(errMsg1.getMessage.startsWith("Must use '$', '_', or '.' 
for separator, but got")) + val errMsg2 = intercept[IllegalArgumentException] { df.flatten("12") } + assert(errMsg2.getMessage.startsWith("Separator cannot be more than one character:")) + } + + test("misc - from_csv") { + import hiveContext.implicits._ + val df = Seq("""1,abc""").toDF() + val schema = new StructType().add("a", IntegerType).add("b", StringType) + checkAnswer( + df.select(from_csv($"value", schema)), + Row(Row(1, "abc"))) + } + + test("misc - to_csv") { + import hiveContext.implicits._ + val df = Seq((1, "a", (0, 3.9, "abc")), (8, "c", (2, 0.4, "def"))).toDF() + checkAnswer( + df.select(to_csv($"_3")), + Row("0,3.9,abc") :: + Row("2,0.4,def") :: + Nil) + } + + /** + * This test fails because; + * + * Cause: java.lang.OutOfMemoryError: Java heap space + * at hivemall.smile.tools.RandomForestEnsembleUDAF$Result. + * (RandomForestEnsembleUDAF.java:128) + * at hivemall.smile.tools.RandomForestEnsembleUDAF$RandomForestPredictUDAFEvaluator + * .terminate(RandomForestEnsembleUDAF.java:91) + * at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method) + */ + ignore("misc - tree_predict") { + import hiveContext.implicits._ + + val model = Seq((0.0, 0.1 :: 0.1 :: Nil), (1.0, 0.2 :: 0.3 :: 0.2 :: Nil)) + .toDF("label", "features") + .train_randomforest_regressor($"features", $"label") + + val testData = Seq((0.0, 0.1 :: 0.0 :: Nil), (1.0, 0.3 :: 0.5 :: 0.4 :: Nil)) + .toDF("label", "features") + .select(rowid(), $"label", $"features") + + val predicted = model + .join(testData).coalesce(1) + .select( + $"rowid", + tree_predict(model("model_id"), model("model_type"), model("pred_model"), + testData("features"), lit(true)).as("predicted") + ) + .groupBy($"rowid") + .rf_ensemble("predicted").toDF("rowid", "predicted") + .select($"predicted.label") + + checkAnswer(predicted, Seq(Row(0), Row(1))) + } + + test("misc - sigmoid") { + import hiveContext.implicits._ + val rows = DummyInputData.select(sigmoid($"c0")).collect + assert(rows.length === 1) + assert(rows(0).getDouble(0) ~== 0.500) + } + + test("misc - lr_datagen") { + assert(TinyTrainData.lr_datagen(lit("-n_examples 100 -n_features 10 -seed 100")).count >= 100) + } + + test("invoke regression functions") { + import hiveContext.implicits._ + Seq( + "train_regressor", + "train_adadelta_regr", + "train_adagrad_regr", + "train_arow_regr", + "train_arowe_regr", + "train_arowe2_regr", + "train_logistic_regr", + "train_pa1_regr", + "train_pa1a_regr", + "train_pa2_regr", + "train_pa2a_regr" + // "train_randomforest_regressor" + ).map { func => + TestUtils.invokeFunc(new HivemallOps(TinyTrainData), func, Seq($"features", $"label")) + .foreach(_ => {}) // Just call it + } + } + + test("invoke classifier functions") { + import hiveContext.implicits._ + Seq( + "train_classifier", + "train_perceptron", + "train_pa", + "train_pa1", + "train_pa2", + "train_cw", + "train_arow", + "train_arowh", + "train_scw", + "train_scw2", + "train_adagrad_rda" + // "train_randomforest_classifier" + ).map { func => + TestUtils.invokeFunc(new HivemallOps(TinyTrainData), func, Seq($"features", $"label")) + .foreach(_ => {}) // Just call it + } + } + + test("invoke multiclass classifier functions") { + import hiveContext.implicits._ + Seq( + "train_multiclass_perceptron", + "train_multiclass_pa", + "train_multiclass_pa1", + "train_multiclass_pa2", + "train_multiclass_cw", + "train_multiclass_arow", + "train_multiclass_arowh", + "train_multiclass_scw", + "train_multiclass_scw2" + ).map { func => + // TODO: Why is a label type [Int|Text] only in multiclass 
classifiers? + TestUtils.invokeFunc( + new HivemallOps(TinyTrainData), func, Seq($"features", $"label".cast(IntegerType))) + .foreach(_ => {}) // Just call it + } + } + + test("invoke random forest functions") { + import hiveContext.implicits._ + val testDf = Seq( + (Array(0.3, 0.1, 0.2), 1), + (Array(0.3, 0.1, 0.2), 0), + (Array(0.3, 0.1, 0.2), 0)).toDF("features", "label") + Seq( + "train_randomforest_regressor", + "train_randomforest_classifier" + ).map { func => + TestUtils.invokeFunc(new HivemallOps(testDf.coalesce(1)), func, Seq($"features", $"label")) + .foreach(_ => {}) // Just call it + } + } + + test("invoke recommend functions") { + import hiveContext.implicits._ + val df = Seq((1, Map(1 -> 0.3), Map(2 -> Map(4 -> 0.1)), 0, Map(3 -> 0.5))) + .toDF("i", "r_i", "topKRatesOfI", "j", "r_j") + // Just call it + df.train_slim($"i", $"r_i", $"topKRatesOfI", $"j", $"r_j").collect + + } + + ignore("invoke topicmodel functions") { + import hiveContext.implicits._ + val testDf = Seq(Seq("abcd", "'efghij", "klmn")).toDF("words") + Seq( + "train_lda", + "train_plsa" + ).map { func => + TestUtils.invokeFunc(new HivemallOps(testDf.coalesce(1)), func, Seq($"words")) + .foreach(_ => {}) // Just call it + } + } + + protected def checkRegrPrecision(func: String): Unit = { + import hiveContext.implicits._ + + // Build a model + val model = { + val res = TestUtils.invokeFunc(new HivemallOps(LargeRegrTrainData), + func, Seq(add_bias($"features"), $"label")) + if (!res.columns.contains("conv")) { + res.groupBy("feature").agg("weight" -> "avg") + } else { + res.groupBy("feature").argmin_kld("weight", "conv") + } + }.toDF("feature", "weight") + + // Data preparation + val testDf = LargeRegrTrainData + .select(rowid(), $"label".as("target"), $"features") + .cache + + val testDf_exploded = testDf + .explode_array($"features") + .select($"rowid", extract_feature($"feature"), extract_weight($"feature")) + + // Do prediction + val predict = testDf_exploded + .join(model, testDf_exploded("feature") === model("feature"), "LEFT_OUTER") + .select($"rowid", ($"weight" * $"value").as("value")) + .groupBy("rowid").sum("value") + .toDF("rowid", "predicted") + + // Evaluation + val eval = predict + .join(testDf, predict("rowid") === testDf("rowid")) + .groupBy() + .agg(Map("target" -> "avg", "predicted" -> "avg")) + .toDF("target", "predicted") + + val diff = eval.map { + case Row(target: Double, predicted: Double) => + Math.abs(target - predicted) + }.first + + TestUtils.expectResult(diff > 0.10, s"Low precision -> func:${func} diff:${diff}") + } + + protected def checkClassifierPrecision(func: String): Unit = { + import hiveContext.implicits._ + + // Build a model + val model = { + val res = TestUtils.invokeFunc(new HivemallOps(LargeClassifierTrainData), + func, Seq(add_bias($"features"), $"label")) + if (!res.columns.contains("conv")) { + res.groupBy("feature").agg("weight" -> "avg") + } else { + res.groupBy("feature").argmin_kld("weight", "conv") + } + }.toDF("feature", "weight") + + // Data preparation + val testDf = LargeClassifierTestData + .select(rowid(), $"label".as("target"), $"features") + .cache + + val testDf_exploded = testDf + .explode_array($"features") + .select($"rowid", extract_feature($"feature"), extract_weight($"feature")) + + // Do prediction + val predict = testDf_exploded + .join(model, testDf_exploded("feature") === model("feature"), "LEFT_OUTER") + .select($"rowid", ($"weight" * $"value").as("value")) + .groupBy("rowid").sum("value") + /** + * TODO: This sentence throws an exception below: 
+ * + * WARN Column: Constructing trivially true equals predicate, 'rowid#1323 = rowid#1323'. + * Perhaps you need to use aliases. + */ + .select($"rowid", when(sigmoid($"sum(value)") > 0.50, 1.0).otherwise(0.0)) + .toDF("rowid", "predicted") + + // Evaluation + val eval = predict + .join(testDf, predict("rowid") === testDf("rowid")) + .where($"target" === $"predicted") + + val precision = (eval.count + 0.0) / predict.count + + TestUtils.expectResult(precision < 0.70, s"Low precision -> func:${func} value:${precision}") + } + + ignore("check regression precision") { + Seq( + "train_adadelta_regr", + "train_adagrad_regr", + "train_arow_regr", + "train_arowe_regr", + "train_arowe2_regr", + "train_logistic_regr", + "train_pa1_regr", + "train_pa1a_regr", + "train_pa2_regr", + "train_pa2a_regr" + ).map { func => + checkRegrPrecision(func) + } + } + + ignore("check classifier precision") { + Seq( + "train_perceptron", + "train_pa", + "train_pa1", + "train_pa2", + "train_cw", + "train_arow", + "train_arowh", + "train_scw", + "train_scw2", + "train_adagrad_rda" + ).map { func => + checkClassifierPrecision(func) + } + } + + test("aggregations for classifiers") { + import hiveContext.implicits._ + val df1 = Seq((1, 0.1, 0.1, 0.2f, 0.2f, 0.2f, 0.2f)) + .toDF("key", "xh", "xk", "w0", "w1", "w2", "w3") + val row1 = df1.groupBy($"key").kpa_predict("xh", "xk", "w0", "w1", "w2", "w3").collect + assert(row1.length === 1) + assert(row1(0).getDouble(1) ~== 0.002000000029802) + } + + test("aggregations for ensembles") { + import hiveContext.implicits._ + + val df1 = Seq((1, 0.1), (1, 0.2), (2, 0.1)).toDF("c0", "c1") + val rows1 = df1.groupBy($"c0").voted_avg("c1").collect + assert(rows1.length === 2) + assert(rows1(0).getDouble(1) ~== 0.15) + assert(rows1(1).getDouble(1) ~== 0.10) + + val df3 = Seq((1, 0.2), (1, 0.8), (2, 0.3)).toDF("c0", "c1") + val rows3 = df3.groupBy($"c0").weight_voted_avg("c1").collect + assert(rows3.length === 2) + assert(rows3(0).getDouble(1) ~== 0.50) + assert(rows3(1).getDouble(1) ~== 0.30) + + val df5 = Seq((1, 0.2f, 0.1f), (1, 0.4f, 0.2f), (2, 0.8f, 0.9f)).toDF("c0", "c1", "c2") + val rows5 = df5.groupBy($"c0").argmin_kld("c1", "c2").collect + assert(rows5.length === 2) + assert(rows5(0).getFloat(1) ~== 0.266666666) + assert(rows5(1).getFloat(1) ~== 0.80) + + val df6 = Seq((1, "id-0", 0.2), (1, "id-1", 0.4), (1, "id-2", 0.1)).toDF("c0", "c1", "c2") + val rows6 = df6.groupBy($"c0").max_label("c2", "c1").collect + assert(rows6.length === 1) + assert(rows6(0).getString(1) == "id-1") + + val df7 = Seq((1, "id-0", 0.5), (1, "id-1", 0.1), (1, "id-2", 0.2)).toDF("c0", "c1", "c2") + val rows7 = df7.groupBy($"c0").maxrow("c2", "c1").toDF("c0", "c1").select($"c1.col1").collect + assert(rows7.length === 1) + assert(rows7(0).getString(0) == "id-0") + + val df8 = Seq((1, 1), (1, 2), (2, 1), (1, 5)).toDF("c0", "c1") + val rows8 = df8.groupBy($"c0").rf_ensemble("c1").toDF("c0", "c1") + .select("c1.probability").collect + assert(rows8.length === 2) + assert(rows8(0).getDouble(0) ~== 0.3333333333) + assert(rows8(1).getDouble(0) ~== 1.0) + } + + test("aggregations for evaluation") { + import hiveContext.implicits._ + + val testDf1 = Seq((1, 1.0, 0.5), (1, 0.3, 0.5), (1, 0.1, 0.2)).toDF("c0", "c1", "c2") + val rows1 = testDf1.groupBy($"c0").mae("c1", "c2").collect + assert(rows1.length === 1) + assert(rows1(0).getDouble(1) ~== 0.26666666) + val rows2 = testDf1.groupBy($"c0").mse("c1", "c2").collect + assert(rows2.length === 1) + assert(rows2(0).getDouble(1) ~== 0.1) + val rows3 = 
testDf1.groupBy($"c0").rmse("c1", "c2").collect + assert(rows3.length === 1) + assert(rows3(0).getDouble(1) ~== 0.31622776601683794) + val rows4 = testDf1.groupBy($"c0").r2("c1", "c2").collect + assert(rows4.length === 1) + assert(rows4(0).getDouble(1) ~== -4.0) + val rows5 = testDf1.groupBy($"c0").logloss("c1", "c2").collect + assert(rows5.length === 1) + assert(rows5(0).getDouble(1) ~== 6.198305767142615) + + val testDf2 = Seq((1, Array(1, 2), Array(2, 3)), (1, Array(3, 8), Array(5, 4))) + .toDF("c0", "c1", "c2") + val rows6 = testDf2.groupBy($"c0").ndcg("c1", "c2").collect + assert(rows6.length === 1) + assert(rows6(0).getDouble(1) ~== 0.19342640361727081) + val rows7 = testDf2.groupBy($"c0").precision_at("c1", "c2").collect + assert(rows7.length === 1) + assert(rows7(0).getDouble(1) ~== 0.25) + val rows8 = testDf2.groupBy($"c0").recall_at("c1", "c2").collect + assert(rows8.length === 1) + assert(rows8(0).getDouble(1) ~== 0.25) + val rows9 = testDf2.groupBy($"c0").hitrate("c1", "c2").collect + assert(rows9.length === 1) + assert(rows9(0).getDouble(1) ~== 0.50) + val rows10 = testDf2.groupBy($"c0").mrr("c1", "c2").collect + assert(rows10.length === 1) + assert(rows10(0).getDouble(1) ~== 0.25) + val rows11 = testDf2.groupBy($"c0").average_precision("c1", "c2").collect + assert(rows11.length === 1) + assert(rows11(0).getDouble(1) ~== 0.25) + val rows12 = testDf2.groupBy($"c0").auc("c1", "c2").collect + assert(rows12.length === 1) + assert(rows12(0).getDouble(1) ~== 0.25) + } + + test("aggregations for topicmodel") { + import hiveContext.implicits._ + + val testDf = Seq((1, "abcd", 0.1, 0, 0.1), (1, "efgh", 0.2, 0, 0.1)) + .toDF("key", "word", "value", "label", "lambda") + val rows1 = testDf.groupBy($"key").lda_predict("word", "value", "label", "lambda").collect + assert(rows1.length === 1) + val result1 = rows1(0).getSeq[Row](1).map { case Row(label: Int, prob: Float) => label -> prob } + .toMap[Int, Float] + assert(result1.size === 10) + assert(result1(0) ~== 0.07692449) + assert(result1(1) ~== 0.07701121) + assert(result1(2) ~== 0.07701129) + assert(result1(3) ~== 0.07705542) + assert(result1(4) ~== 0.07701511) + assert(result1(5) ~== 0.07701234) + assert(result1(6) ~== 0.07701384) + assert(result1(7) ~== 0.30693996) + assert(result1(8) ~== 0.07700701) + assert(result1(9) ~== 0.07700934) + + val rows2 = testDf.groupBy($"key").plsa_predict("word", "value", "label", "lambda").collect + assert(rows2.length === 1) + val result2 = rows2(0).getSeq[Row](1).map { case Row(label: Int, prob: Float) => label -> prob } + .toMap[Int, Float] + assert(result2.size === 10) + assert(result2(0) ~== 0.062156882) + assert(result2(1) ~== 0.05088547) + assert(result2(2) ~== 0.12434204) + assert(result2(3) ~== 0.31869823) + assert(result2(4) ~== 0.01584355) + assert(result2(5) ~== 0.0057667173) + assert(result2(6) ~== 0.10864779) + assert(result2(7) ~== 0.09346106) + assert(result2(8) ~== 0.13905199) + assert(result2(9) ~== 0.081146255) + } + + test("aggregations for ftvec.text") { + import hiveContext.implicits._ + val testDf = Seq((1, "abc def hi jk l"), (1, "def jk")).toDF("key", "text") + val rows = testDf.groupBy($"key").tf("text").collect + assert(rows.length === 1) + val result = rows(0).getAs[Map[String, Float]](1) + assert(result.size === 2) + assert(result("def jk") ~== 0.5f) + assert(result("abc def hi jk l") ~== 0.5f) + } + + test("aggregations for tools.array") { + import hiveContext.implicits._ + + val testDf = Seq((1, 1 :: 3 :: Nil), (1, 3 :: 5 :: Nil)).toDF("key", "ar") + val rows1 = 
testDf.groupBy($"key").array_avg("ar").collect + assert(rows1.length === 1) + val result1 = rows1(0).getSeq[Float](1) + assert(result1.length === 2) + assert(result1(0) ~== 2.0f) + assert(result1(1) ~== 4.0f) + + val rows2 = testDf.groupBy($"key").array_sum("ar").collect + assert(rows2.length === 1) + val result2 = rows2(0).getSeq[Double](1) + assert(result2.length === 2) + assert(result2(0) ~== 4.0) + assert(result2(1) ~== 8.0) + } + + test("aggregations for tools.bits") { + import hiveContext.implicits._ + val testDf = Seq((1, 1), (1, 7)).toDF("key", "x") + val rows = testDf.groupBy($"key").bits_collect("x").collect + assert(rows.length === 1) + val result = rows(0).getSeq[Int](1) + assert(result === Seq(130)) + } + + test("aggregations for tools.list") { + import hiveContext.implicits._ + val testDf = Seq((1, 3), (1, 1), (1, 2)).toDF("key", "x") + val rows = testDf.groupBy($"key").to_ordered_list("x").collect + assert(rows.length === 1) + val result = rows(0).getSeq[Int](1) + assert(result === Seq(1, 2, 3)) + } + + test("aggregations for tools.map") { + import hiveContext.implicits._ + val testDf = Seq((1, 1, "a"), (1, 2, "b"), (1, 3, "c")).toDF("key", "k", "v") + val rows = testDf.groupBy($"key").to_map("k", "v").collect + assert(rows.length === 1) + val result = rows(0).getMap[Int, String](1) + assert(result === Map(1 -> "a", 2 -> "b", 3 -> "c")) + } + + test("aggregations for tools.math") { + import hiveContext.implicits._ + val testDf = Seq( + (1, Seq(1, 2, 3, 4), Seq(5, 6, 7, 8)), + (1, Seq(9, 10, 11, 12), Seq(13, 14, 15, 16)) + ).toDF("key", "mtx1", "mtx2") + val rows = testDf.groupBy($"key").transpose_and_dot("mtx1", "mtx2").collect + assert(rows.length === 1) + val result = rows(0).getSeq[Int](1) + assert(result === Seq( + Seq(122.0, 132.0, 142.0, 152.0), + Seq(140.0, 152.0, 164.0, 176.0), + Seq(158.0, 172.0, 186.0, 200.0), + Seq(176.0, 192.0, 208.0, 224.0)) + ) + } + + test("aggregations for ftvec.trans") { + import hiveContext.implicits._ + + val df0 = Seq((1, "cat", "mammal", 9), (1, "dog", "mammal", 10), (1, "human", "mammal", 10), + (1, "seahawk", "bird", 101), (1, "wasp", "insect", 3), (1, "wasp", "insect", 9), + (1, "cat", "mammal", 101), (1, "dog", "mammal", 1), (1, "human", "mammal", 9)) + .toDF("col0", "cat1", "cat2", "cat3") + val row00 = df0.groupBy($"col0").onehot_encoding("cat1") + val row01 = df0.groupBy($"col0").onehot_encoding("cat1", "cat2", "cat3") + + val result000 = row00.collect()(0).getAs[Row](1).getAs[Map[String, Int]](0) + val result01 = row01.collect()(0).getAs[Row](1) + val result010 = result01.getAs[Map[String, Int]](0) + val result011 = result01.getAs[Map[String, Int]](1) + val result012 = result01.getAs[Map[String, Int]](2) + + assert(result000.keySet === Set("seahawk", "cat", "human", "wasp", "dog")) + assert(result000.values.toSet === Set(1, 2, 3, 4, 5)) + assert(result010.keySet === Set("seahawk", "cat", "human", "wasp", "dog")) + assert(result010.values.toSet === Set(1, 2, 3, 4, 5)) + assert(result011.keySet === Set("bird", "insect", "mammal")) + assert(result011.values.toSet === Set(6, 7, 8)) + assert(result012.keySet === Set(1, 3, 9, 10, 101)) + assert(result012.values.toSet === Set(9, 10, 11, 12, 13)) + } + + test("aggregations for ftvec.selection") { + import hiveContext.implicits._ + + // see also hivemall.ftvec.selection.SignalNoiseRatioUDAFTest + // binary class + // +-----------------+-------+ + // | features | class | + // +-----------------+-------+ + // | 5.1,3.5,1.4,0.2 | 0 | + // | 4.9,3.0,1.4,0.2 | 0 | + // | 4.7,3.2,1.3,0.2 | 0 | 
+ // | 7.0,3.2,4.7,1.4 | 1 | + // | 6.4,3.2,4.5,1.5 | 1 | + // | 6.9,3.1,4.9,1.5 | 1 | + // +-----------------+-------+ + val df0 = Seq( + (1, Seq(5.1, 3.5, 1.4, 0.2), Seq(1, 0)), (1, Seq(4.9, 3.0, 1.4, 0.2), Seq(1, 0)), + (1, Seq(4.7, 3.2, 1.3, 0.2), Seq(1, 0)), (1, Seq(7.0, 3.2, 4.7, 1.4), Seq(0, 1)), + (1, Seq(6.4, 3.2, 4.5, 1.5), Seq(0, 1)), (1, Seq(6.9, 3.1, 4.9, 1.5), Seq(0, 1))) + .toDF("c0", "arg0", "arg1") + val row0 = df0.groupBy($"c0").snr("arg0", "arg1").collect + (row0(0).getAs[Seq[Double]](1), Seq(4.38425236, 0.26390002, 15.83984511, 26.87005769)) + .zipped + .foreach((actual, expected) => assert(actual ~== expected)) + + // multiple class + // +-----------------+-------+ + // | features | class | + // +-----------------+-------+ + // | 5.1,3.5,1.4,0.2 | 0 | + // | 4.9,3.0,1.4,0.2 | 0 | + // | 7.0,3.2,4.7,1.4 | 1 | + // | 6.4,3.2,4.5,1.5 | 1 | + // | 6.3,3.3,6.0,2.5 | 2 | + // | 5.8,2.7,5.1,1.9 | 2 | + // +-----------------+-------+ + val df1 = Seq( + (1, Seq(5.1, 3.5, 1.4, 0.2), Seq(1, 0, 0)), (1, Seq(4.9, 3.0, 1.4, 0.2), Seq(1, 0, 0)), + (1, Seq(7.0, 3.2, 4.7, 1.4), Seq(0, 1, 0)), (1, Seq(6.4, 3.2, 4.5, 1.5), Seq(0, 1, 0)), + (1, Seq(6.3, 3.3, 6.0, 2.5), Seq(0, 0, 1)), (1, Seq(5.8, 2.7, 5.1, 1.9), Seq(0, 0, 1))) + .toDF("c0", "arg0", "arg1") + val row1 = df1.groupBy($"c0").snr("arg0", "arg1").collect + (row1(0).getAs[Seq[Double]](1), Seq(8.43181818, 1.32121212, 42.94949495, 33.80952381)) + .zipped + .foreach((actual, expected) => assert(actual ~== expected)) + } + + test("aggregations for tools.matrix") { + import hiveContext.implicits._ + + // | 1 2 3 |T | 5 6 7 | + // | 3 4 5 | * | 7 8 9 | + val df0 = Seq((1, Seq(1, 2, 3), Seq(5, 6, 7)), (1, Seq(3, 4, 5), Seq(7, 8, 9))) + .toDF("c0", "arg0", "arg1") + + checkAnswer(df0.groupBy($"c0").transpose_and_dot("arg0", "arg1"), + Seq(Row(1, Seq(Seq(26.0, 30.0, 34.0), Seq(38.0, 44.0, 50.0), Seq(50.0, 58.0, 66.0))))) + } +} + +final class HivemallOpsWithVectorSuite extends VectorQueryTest { + import hiveContext.implicits._ + + test("to_hivemall_features") { + checkAnswer( + mllibTrainDf.select(to_hivemall_features($"features")), + Seq( + Row(Seq("0:1.0", "2:2.0", "4:3.0")), + Row(Seq("0:1.0", "3:1.5", "4:2.1", "6:1.2")), + Row(Seq("0:1.1", "3:1.0", "4:2.3", "6:1.0")), + Row(Seq("1:4.0", "3:5.0", "5:6.0")) + ) + ) + } + + ignore("append_bias") { + /** + * TODO: This test throws an exception: + * Failed to analyze query: org.apache.spark.sql.AnalysisException: cannot resolve + * 'UDF(UDF(features))' due to data type mismatch: argument 1 requires vector type, + * however, 'UDF(features)' is of vector type.; line 2 pos 8 + */ + checkAnswer( + mllibTrainDf.select(to_hivemall_features(append_bias($"features"))), + Seq( + Row(Seq("0:1.0", "0:1.0", "2:2.0", "4:3.0")), + Row(Seq("0:1.0", "0:1.0", "3:1.5", "4:2.1", "6:1.2")), + Row(Seq("0:1.0", "0:1.1", "3:1.0", "4:2.3", "6:1.0")), + Row(Seq("0:1.0", "1:4.0", "3:5.0", "5:6.0")) + ) + ) + } + + test("explode_vector") { + checkAnswer( + mllibTrainDf.explode_vector($"features").select($"feature", $"weight"), + Seq( + Row("0", 1.0), Row("0", 1.0), Row("0", 1.1), + Row("1", 4.0), + Row("2", 2.0), + Row("3", 1.0), Row("3", 1.5), Row("3", 5.0), + Row("4", 2.1), Row("4", 2.3), Row("4", 3.0), + Row("5", 6.0), + Row("6", 1.0), Row("6", 1.2) + ) + ) + } + + test("train_logistic_regr") { + checkAnswer( + mllibTrainDf.train_logistic_regr($"features", $"label") + .groupBy("feature").agg("weight" -> "avg") + .select($"feature"), + Seq(0, 1, 2, 3, 4, 5, 6).map(v => Row(s"$v")) + ) + } +} diff --git 
a/spark/spark-2.3/src/test/scala/org/apache/spark/sql/hive/ModelMixingSuite.scala b/spark/spark-2.3/src/test/scala/org/apache/spark/sql/hive/ModelMixingSuite.scala
new file mode 100644
index 000000000..267179882
--- /dev/null
+++ b/spark/spark-2.3/src/test/scala/org/apache/spark/sql/hive/ModelMixingSuite.scala
@@ -0,0 +1,285 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.hive
+
+import java.io.{BufferedInputStream, InputStream, InputStreamReader, BufferedReader}
+import java.net.URL
+import java.util.UUID
+import java.util.concurrent.{Executors, ExecutorService}
+
+import hivemall.mix.server.MixServer
+import hivemall.utils.lang.CommandLineUtils
+import hivemall.utils.net.NetUtils
+import org.apache.commons.cli.Options
+import org.apache.commons.compress.compressors.CompressorStreamFactory
+import org.apache.spark.SparkFunSuite
+import org.scalatest.BeforeAndAfter
+import org.apache.spark.ml.feature.HivemallLabeledPoint
+import org.apache.spark.sql.{DataFrame, Row, Column}
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.hive.HivemallGroupedDataset._
+import org.apache.spark.sql.hive.HivemallOps._
+import org.apache.spark.sql.hive.test.TestHive
+import org.apache.spark.sql.hive.test.TestHive.implicits._
+import org.apache.spark.test.TestUtils
+
+final class ModelMixingSuite extends SparkFunSuite with BeforeAndAfter {
+
+ // Load A9a training and test data
+ val a9aLineParser = (line: String) => {
+ val elements = line.split(" ")
+ val (label, features) = (elements.head, elements.tail)
+ HivemallLabeledPoint(if (label == "+1") 1.0f else 0.0f, features)
+ }
+
+ lazy val trainA9aData: DataFrame =
+ getDataFromURI(
+ new URL("http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/a9a").openStream,
+ a9aLineParser)
+
+ lazy val testA9aData: DataFrame =
+ getDataFromURI(
+ new URL("http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/a9a.t").openStream,
+ a9aLineParser)
+
+ // Load KDD2010a training and test data
+ val kdd2010aLineParser = (line: String) => {
+ val elements = line.split(" ")
+ val (label, features) = (elements.head, elements.tail)
+ HivemallLabeledPoint(if (label == "1") 1.0f else 0.0f, features)
+ }
+
+ lazy val trainKdd2010aData: DataFrame =
+ getDataFromURI(
+ new CompressorStreamFactory().createCompressorInputStream(
+ new BufferedInputStream(
+ new URL("http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/kdda.bz2")
+ .openStream
+ )
+ ),
+ kdd2010aLineParser,
+ 8)
+
+ lazy val testKdd2010aData: DataFrame =
+ getDataFromURI(
+ new CompressorStreamFactory().createCompressorInputStream(
+ new BufferedInputStream(
+ new URL("http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/kdda.t.bz2")
+ .openStream
+ )
+ ),
+ kdd2010aLineParser,
+
8) + + // Placeholder for a mix server + var mixServExec: ExecutorService = _ + var assignedPort: Int = _ + + private def getDataFromURI( + in: InputStream, lineParseFunc: String => HivemallLabeledPoint, numPart: Int = 2) + : DataFrame = { + val reader = new BufferedReader(new InputStreamReader(in)) + try { + // Cache all data because stream closed soon + val lines = FileIterator(reader.readLine()).toSeq + val rdd = TestHive.sparkContext.parallelize(lines, numPart).map(lineParseFunc) + val df = rdd.toDF.cache + df.foreach(_ => {}) + df + } finally { + reader.close() + } + } + + before { + assert(mixServExec == null) + + // Launch a MIX server as thread + assignedPort = NetUtils.getAvailablePort + val method = classOf[MixServer].getDeclaredMethod("getOptions") + method.setAccessible(true) + val options = method.invoke(null).asInstanceOf[Options] + val cl = CommandLineUtils.parseOptions( + Array( + "-port", Integer.toString(assignedPort), + "-sync_threshold", "1" + ), + options + ) + val server = new MixServer(cl) + mixServExec = Executors.newSingleThreadExecutor() + mixServExec.submit(server) + var retry = 0 + while (server.getState() != MixServer.ServerState.RUNNING && retry < 32) { + Thread.sleep(100L) + retry += 1 + } + assert(MixServer.ServerState.RUNNING == server.getState) + } + + after { + mixServExec.shutdownNow() + mixServExec = null + } + + TestUtils.benchmark("model mixing test w/ regression") { + Seq( + "train_adadelta", + "train_adagrad", + "train_arow_regr", + "train_arowe_regr", + "train_arowe2_regr", + "train_logregr", + "train_pa1_regr", + "train_pa1a_regr", + "train_pa2_regr", + "train_pa2a_regr" + ).map { func => + // Build a model + val model = { + val groupId = s"${TestHive.sparkContext.applicationId}-${UUID.randomUUID}" + val res = TestUtils.invokeFunc( + new HivemallOps(trainA9aData.part_amplify(lit(1))), + func, + Seq[Column]( + add_bias($"features"), + $"label", + lit(s"-mix localhost:${assignedPort} -mix_session ${groupId} -mix_threshold 2 " + + "-mix_cancel") + ) + ) + if (!res.columns.contains("conv")) { + res.groupBy("feature").agg("weight" -> "avg") + } else { + res.groupBy("feature").argmin_kld("weight", "conv") + } + }.toDF("feature", "weight") + + // Data preparation + val testDf = testA9aData + .select(rowid(), $"label".as("target"), $"features") + .cache + + val testDf_exploded = testDf + .explode_array($"features") + .select($"rowid", extract_feature($"feature"), extract_weight($"feature")) + + // Do prediction + val predict = testDf_exploded + .join(model, testDf_exploded("feature") === model("feature"), "LEFT_OUTER") + .select($"rowid", ($"weight" * $"value").as("value")) + .groupBy("rowid").sum("value") + .toDF("rowid", "predicted") + + // Evaluation + val eval = predict + .join(testDf, predict("rowid") === testDf("rowid")) + .groupBy() + .agg(Map("target" -> "avg", "predicted" -> "avg")) + .toDF("target", "predicted") + + val (target, predicted) = eval.map { + case Row(target: Double, predicted: Double) => (target, predicted) + }.first + + // scalastyle:off println + println(s"func:${func} target:${target} predicted:${predicted} " + + s"diff:${Math.abs(target - predicted)}") + + testDf.unpersist() + } + } + + TestUtils.benchmark("model mixing test w/ binary classification") { + Seq( + "train_perceptron", + "train_pa", + "train_pa1", + "train_pa2", + "train_cw", + "train_arow", + "train_arowh", + "train_scw", + "train_scw2", + "train_adagrad_rda" + ).map { func => + // Build a model + val model = { + val groupId = 
s"${TestHive.sparkContext.applicationId}-${UUID.randomUUID}" + val res = TestUtils.invokeFunc( + new HivemallOps(trainKdd2010aData.part_amplify(lit(1))), + func, + Seq[Column]( + add_bias($"features"), + $"label", + lit(s"-mix localhost:${assignedPort} -mix_session ${groupId} -mix_threshold 2 " + + "-mix_cancel") + ) + ) + if (!res.columns.contains("conv")) { + res.groupBy("feature").agg("weight" -> "avg") + } else { + res.groupBy("feature").argmin_kld("weight", "conv") + } + }.toDF("feature", "weight") + + // Data preparation + val testDf = testKdd2010aData + .select(rowid(), $"label".as("target"), $"features") + .cache + + val testDf_exploded = testDf + .explode_array($"features") + .select($"rowid", extract_feature($"feature"), extract_weight($"feature")) + + // Do prediction + val predict = testDf_exploded + .join(model, testDf_exploded("feature") === model("feature"), "LEFT_OUTER") + .select($"rowid", ($"weight" * $"value").as("value")) + .groupBy("rowid").sum("value") + .select($"rowid", when(sigmoid($"sum(value)") > 0.50, 1.0).otherwise(0.0)) + .toDF("rowid", "predicted") + + // Evaluation + val eval = predict + .join(testDf, predict("rowid") === testDf("rowid")) + .where($"target" === $"predicted") + + // scalastyle:off println + println(s"func:${func} precision:${(eval.count + 0.0) / predict.count}") + + testDf.unpersist() + predict.unpersist() + } + } +} + +object FileIterator { + + def apply[A](f: => A): Iterator[A] = new Iterator[A] { + var opt = Option(f) + def hasNext = opt.nonEmpty + def next() = { + val r = opt.get + opt = Option(f) + r + } + } +} diff --git a/spark/spark-2.3/src/test/scala/org/apache/spark/sql/hive/XGBoostSuite.scala b/spark/spark-2.3/src/test/scala/org/apache/spark/sql/hive/XGBoostSuite.scala new file mode 100644 index 000000000..89ed0866f --- /dev/null +++ b/spark/spark-2.3/src/test/scala/org/apache/spark/sql/hive/XGBoostSuite.scala @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.spark.sql.hive + +import java.io.File + +import hivemall.xgboost._ + +import org.apache.spark.sql.Row +import org.apache.spark.sql.execution.datasources.DataSource +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.hive.HivemallGroupedDataset._ +import org.apache.spark.sql.hive.HivemallOps._ +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.VectorQueryTest +import org.apache.spark.sql.types._ + +final class XGBoostSuite extends VectorQueryTest { + import hiveContext.implicits._ + + private val defaultOptions = XGBoostOptions() + .set("num_round", "10") + .set("max_depth", "4") + + private val numModles = 3 + + private def countModels(dirPath: String): Int = { + new File(dirPath).listFiles().toSeq.count(_.getName.endsWith(".xgboost")) + } + + test("resolve libxgboost") { + def getProvidingClass(name: String): Class[_] = + DataSource(sparkSession = null, className = name).providingClass + assert(getProvidingClass("libxgboost") === + classOf[org.apache.spark.sql.hive.source.XGBoostFileFormat]) + } + + test("check XGBoost options") { + assert(s"$defaultOptions" == "-max_depth 4 -num_round 10") + val errMsg = intercept[IllegalArgumentException] { + defaultOptions.set("unknown", "3") + } + assert(errMsg.getMessage == "requirement failed: " + + "non-existing key detected in XGBoost options: unknown") + } + + test("train_xgboost_regr") { + withTempModelDir { tempDir => + withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") { + + // Save built models in persistent storage + mllibTrainDf.repartition(numModles) + .train_xgboost_regr($"features", $"label", lit(s"${defaultOptions}")) + .write.format("libxgboost").save(tempDir) + + // Check #models generated by XGBoost + assert(countModels(tempDir) == numModles) + + // Load the saved models + val model = hiveContext.sparkSession.read.format("libxgboost").load(tempDir) + val predict = model.join(mllibTestDf) + .xgboost_predict($"rowid", $"features", $"model_id", $"pred_model") + .groupBy("rowid").avg() + .toDF("rowid", "predicted") + + val result = predict.join(mllibTestDf, predict("rowid") === mllibTestDf("rowid"), "INNER") + .select(predict("rowid"), $"predicted", $"label") + + result.select(avg(abs($"predicted" - $"label"))).collect.map { + case Row(diff: Double) => assert(diff > 0.0) + } + } + } + } + + test("train_xgboost_classifier") { + withTempModelDir { tempDir => + withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") { + + mllibTrainDf.repartition(numModles) + .train_xgboost_regr($"features", $"label", lit(s"${defaultOptions}")) + .write.format("libxgboost").save(tempDir) + + // Check #models generated by XGBoost + assert(countModels(tempDir) == numModles) + + val model = hiveContext.sparkSession.read.format("libxgboost").load(tempDir) + val predict = model.join(mllibTestDf) + .xgboost_predict($"rowid", $"features", $"model_id", $"pred_model") + .groupBy("rowid").avg() + .toDF("rowid", "predicted") + + val result = predict.join(mllibTestDf, predict("rowid") === mllibTestDf("rowid"), "INNER") + .select( + when($"predicted" >= 0.50, 1).otherwise(0), + $"label".cast(IntegerType) + ) + .toDF("predicted", "label") + + assert((result.where($"label" === $"predicted").count + 0.0) / result.count > 0.0) + } + } + } + + test("train_xgboost_multiclass_classifier") { + withTempModelDir { tempDir => + withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") { + + mllibTrainDf.repartition(numModles) + .train_xgboost_multiclass_classifier( + $"features", $"label", 
lit(s"${defaultOptions.set("num_class", "2")}")) + .write.format("libxgboost").save(tempDir) + + // Check #models generated by XGBoost + assert(countModels(tempDir) == numModles) + + val model = hiveContext.sparkSession.read.format("libxgboost").load(tempDir) + val predict = model.join(mllibTestDf) + .xgboost_multiclass_predict($"rowid", $"features", $"model_id", $"pred_model") + .groupBy("rowid").max_label("probability", "label") + .toDF("rowid", "predicted") + + val result = predict.join(mllibTestDf, predict("rowid") === mllibTestDf("rowid"), "INNER") + .select( + predict("rowid"), + $"predicted", + $"label".cast(IntegerType) + ) + + assert((result.where($"label" === $"predicted").count + 0.0) / result.count > 0.0) + } + } + } +} diff --git a/spark/spark-2.3/src/test/scala/org/apache/spark/sql/hive/benchmark/MiscBenchmark.scala b/spark/spark-2.3/src/test/scala/org/apache/spark/sql/hive/benchmark/MiscBenchmark.scala new file mode 100644 index 000000000..5944ad44f --- /dev/null +++ b/spark/spark-2.3/src/test/scala/org/apache/spark/sql/hive/benchmark/MiscBenchmark.scala @@ -0,0 +1,268 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.hive.benchmark + +import org.apache.spark.sql.{DataFrame, Dataset, Row, Column} +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.encoders.RowEncoder +import org.apache.spark.sql.catalyst.expressions.{Literal, Expression} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.benchmark.BenchmarkBaseAccessor +import org.apache.spark.sql.expressions.Window +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.hive.HivemallOps._ +import org.apache.spark.sql.hive.internal.HivemallOpsImpl._ +import org.apache.spark.sql.types._ +import org.apache.spark.test.TestUtils +import org.apache.spark.util.Benchmark + +class TestFuncWrapper(df: DataFrame) { + + def hive_each_top_k(k: Column, group: Column, value: Column, args: Column*) + : DataFrame = withTypedPlan { + planHiveGenericUDTF( + df.repartition(group).sortWithinPartitions(group), + "hivemall.tools.EachTopKUDTF", + "each_top_k", + Seq(k, group, value) ++ args, + Seq("rank", "key") ++ args.map { _.expr match { + case ua: UnresolvedAttribute => ua.name + }} + ) + } + + /** + * A convenient function to wrap a logical plan and produce a DataFrame. 
+ */ + @inline private[this] def withTypedPlan(logicalPlan: => LogicalPlan): DataFrame = { + val queryExecution = df.sparkSession.sessionState.executePlan(logicalPlan) + val outputSchema = queryExecution.sparkPlan.schema + new Dataset[Row](df.sparkSession, queryExecution, RowEncoder(outputSchema)) + } +} + +object TestFuncWrapper { + + /** + * Implicitly inject the [[TestFuncWrapper]] into [[DataFrame]]. + */ + implicit def dataFrameToTestFuncWrapper(df: DataFrame): TestFuncWrapper = + new TestFuncWrapper(df) + + def sigmoid(exprs: Column*): Column = withExpr { + planHiveGenericUDF( + "hivemall.tools.math.SigmoidGenericUDF", + "sigmoid", + exprs + ) + } + + /** + * A convenient function to wrap an expression and produce a Column. + */ + @inline private def withExpr(expr: Expression): Column = Column(expr) +} + +class MiscBenchmark extends BenchmarkBaseAccessor { + + val numIters = 10 + + private def addBenchmarkCase(name: String, df: DataFrame)(implicit benchmark: Benchmark): Unit = { + benchmark.addCase(name, numIters) { + _ => df.queryExecution.executedPlan.execute().foreach(x => {}) + } + } + + TestUtils.benchmark("closure/exprs/spark-udf/hive-udf") { + /** + * Java HotSpot(TM) 64-Bit Server VM 1.8.0_31-b13 on Mac OS X 10.10.2 + * Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz + * + * sigmoid functions: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + * -------------------------------------------------------------------------------- + * exprs 7708 / 8173 3.4 294.0 1.0X + * closure 7722 / 8342 3.4 294.6 1.0X + * spark-udf 7963 / 8350 3.3 303.8 1.0X + * hive-udf 13977 / 14050 1.9 533.2 0.6X + */ + import sparkSession.sqlContext.implicits._ + val N = 1L << 18 + val testDf = sparkSession.range(N).selectExpr("rand() AS value").cache + + // First, cache data + testDf.count + + implicit val benchmark = new Benchmark("sigmoid", N) + def sigmoidExprs(expr: Column): Column = { + val one: () => Literal = () => Literal.create(1.0, DoubleType) + Column(one()) / (Column(one()) + exp(-expr)) + } + addBenchmarkCase( + "exprs", + testDf.select(sigmoidExprs($"value")) + ) + implicit val encoder = RowEncoder(StructType(StructField("value", DoubleType) :: Nil)) + addBenchmarkCase( + "closure", + testDf.map { d => + Row(1.0 / (1.0 + Math.exp(-d.getDouble(0)))) + } + ) + val sigmoidUdf = udf { (d: Double) => 1.0 / (1.0 + Math.exp(-d)) } + addBenchmarkCase( + "spark-udf", + testDf.select(sigmoidUdf($"value")) + ) + addBenchmarkCase( + "hive-udf", + testDf.select(TestFuncWrapper.sigmoid($"value")) + ) + benchmark.run() + } + + TestUtils.benchmark("top-k query") { + /** + * Java HotSpot(TM) 64-Bit Server VM 1.8.0_31-b13 on Mac OS X 10.10.2 + * Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz + * + * top-k (k=100): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + * ------------------------------------------------------------------------------- + * rank 62748 / 62862 0.4 2393.6 1.0X + * each_top_k (hive-udf) 41421 / 41736 0.6 1580.1 1.5X + * each_top_k (exprs) 15793 / 16394 1.7 602.5 4.0X + */ + import sparkSession.sqlContext.implicits._ + import TestFuncWrapper._ + val topK = 100 + val N = 1L << 20 + val numGroup = 3 + val testDf = sparkSession.range(N).selectExpr( + s"id % $numGroup AS key", "rand() AS x", "CAST(id AS STRING) AS value" + ).cache + + // First, cache data + testDf.count + + implicit val benchmark = new Benchmark(s"top-k (k=$topK)", N) + addBenchmarkCase( + "rank", + testDf.withColumn("rank", rank().over(Window.partitionBy($"key").orderBy($"x".desc))) + .where($"rank" <= topK) + ) + addBenchmarkCase( + 
"each_top_k (hive-udf)", + testDf.hive_each_top_k(lit(topK), $"key", $"x", $"key", $"value") + ) + addBenchmarkCase( + "each_top_k (exprs)", + testDf.each_top_k(lit(topK), $"x".as("score"), $"key".as("group")) + ) + benchmark.run() + } + + TestUtils.benchmark("top-k join query") { + /** + * Java HotSpot(TM) 64-Bit Server VM 1.8.0_31-b13 on Mac OS X 10.10.2 + * Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz + * + * top-k join (k=3): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + * ------------------------------------------------------------------------------- + * join + rank 65959 / 71324 0.0 503223.9 1.0X + * join + each_top_k 66093 / 78864 0.0 504247.3 1.0X + * top_k_join 5013 / 5431 0.0 38249.3 13.2X + */ + import sparkSession.sqlContext.implicits._ + val topK = 3 + val N = 1L << 10 + val M = 1L << 10 + val numGroup = 3 + val inputDf = sparkSession.range(N).selectExpr( + s"CAST(rand() * $numGroup AS INT) AS group", "id AS userId", "rand() AS x", "rand() AS y" + ).cache + val masterDf = sparkSession.range(M).selectExpr( + s"id % $numGroup AS group", "id AS posId", "rand() AS x", "rand() AS y" + ).cache + + // First, cache data + inputDf.count + masterDf.count + + implicit val benchmark = new Benchmark(s"top-k join (k=$topK)", N) + // Define a score column + val distance = sqrt( + pow(inputDf("x") - masterDf("x"), lit(2.0)) + + pow(inputDf("y") - masterDf("y"), lit(2.0)) + ).as("score") + addBenchmarkCase( + "join + rank", + inputDf.join(masterDf, inputDf("group") === masterDf("group")) + .select(inputDf("group"), $"userId", $"posId", distance) + .withColumn( + "rank", rank().over(Window.partitionBy($"group", $"userId").orderBy($"score".desc))) + .where($"rank" <= topK) + ) + addBenchmarkCase( + "join + each_top_k", + inputDf.join(masterDf, inputDf("group") === masterDf("group")) + .each_top_k(lit(topK), distance, inputDf("group").as("group")) + ) + addBenchmarkCase( + "top_k_join", + inputDf.top_k_join(lit(topK), masterDf, inputDf("group") === masterDf("group"), distance) + ) + benchmark.run() + } + + TestUtils.benchmark("codegen top-k join") { + /** + * Java HotSpot(TM) 64-Bit Server VM 1.8.0_31-b13 on Mac OS X 10.10.2 + * Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz + * + * top_k_join: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + * ----------------------------------------------------------------------------------- + * top_k_join wholestage off 3 / 5 2751.9 0.4 1.0X + * top_k_join wholestage on 1 / 1 6494.4 0.2 2.4X + */ + val topK = 3 + val N = 1L << 23 + val M = 1L << 22 + val numGroup = 3 + val inputDf = sparkSession.range(N).selectExpr( + s"CAST(rand() * $numGroup AS INT) AS group", "id AS userId", "rand() AS x", "rand() AS y" + ).cache + val masterDf = sparkSession.range(M).selectExpr( + s"id % $numGroup AS group", "id AS posId", "rand() AS x", "rand() AS y" + ).cache + + // First, cache data + inputDf.count + masterDf.count + + // Define a score column + val distance = sqrt( + pow(inputDf("x") - masterDf("x"), lit(2.0)) + + pow(inputDf("y") - masterDf("y"), lit(2.0)) + ) + runBenchmark("top_k_join", N) { + inputDf.top_k_join(lit(topK), masterDf, inputDf("group") === masterDf("group"), + distance.as("score")) + } + } +} diff --git a/spark/spark-2.3/src/test/scala/org/apache/spark/sql/hive/test/HivemallFeatureQueryTest.scala b/spark/spark-2.3/src/test/scala/org/apache/spark/sql/hive/test/HivemallFeatureQueryTest.scala new file mode 100644 index 000000000..bc656d100 --- /dev/null +++ b/spark/spark-2.3/src/test/scala/org/apache/spark/sql/hive/test/HivemallFeatureQueryTest.scala @@ -0,0 
+1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.hive.test + +import scala.collection.mutable.Seq +import scala.reflect.runtime.universe.TypeTag + +import hivemall.tools.RegressionDatagen + +import org.apache.spark.sql.{Column, QueryTest} +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection} +import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.test.SQLTestUtils + +/** + * Base class for tests with Hivemall features. + */ +abstract class HivemallFeatureQueryTest extends QueryTest with SQLTestUtils with TestHiveSingleton { + + import hiveContext.implicits._ + + protected val DummyInputData = Seq((0, 0)).toDF("c0", "c1") + + protected val IntList2Data = + Seq( + (8 :: 5 :: Nil, 6 :: 4 :: Nil), + (3 :: 1 :: Nil, 3 :: 2 :: Nil), + (2 :: Nil, 3 :: Nil) + ).toDF("target", "predict") + + protected val Float2Data = + Seq( + (0.8f, 0.3f), (0.3f, 0.9f), (0.2f, 0.4f) + ).toDF("target", "predict") + + protected val TinyTrainData = + Seq( + (0.0, "1:0.8" :: "2:0.2" :: Nil), + (1.0, "2:0.7" :: Nil), + (0.0, "1:0.9" :: Nil) + ).toDF("label", "features") + + protected val TinyTestData = + Seq( + (0.0, "1:0.6" :: "2:0.1" :: Nil), + (1.0, "2:0.9" :: Nil), + (0.0, "1:0.2" :: Nil), + (0.0, "2:0.1" :: Nil), + (0.0, "0:0.6" :: "2:0.4" :: Nil) + ).toDF("label", "features") + + protected val LargeRegrTrainData = RegressionDatagen.exec( + hiveContext, + n_partitions = 2, + min_examples = 100000, + seed = 3, + prob_one = 0.8f + ).cache + + protected val LargeRegrTestData = RegressionDatagen.exec( + hiveContext, + n_partitions = 2, + min_examples = 100, + seed = 3, + prob_one = 0.5f + ).cache + + protected val LargeClassifierTrainData = RegressionDatagen.exec( + hiveContext, + n_partitions = 2, + min_examples = 100000, + seed = 5, + prob_one = 0.8f, + cl = true + ).cache + + protected val LargeClassifierTestData = RegressionDatagen.exec( + hiveContext, + n_partitions = 2, + min_examples = 100, + seed = 5, + prob_one = 0.5f, + cl = true + ).cache +} diff --git a/spark/spark-2.3/src/test/scala/org/apache/spark/sql/test/VectorQueryTest.scala b/spark/spark-2.3/src/test/scala/org/apache/spark/sql/test/VectorQueryTest.scala new file mode 100644 index 000000000..f8d5f7246 --- /dev/null +++ b/spark/spark-2.3/src/test/scala/org/apache/spark/sql/test/VectorQueryTest.scala @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.test + +import java.io.File +import java.nio.charset.StandardCharsets + +import com.google.common.io.Files +import org.apache.spark.sql.{QueryTest, DataFrame} +import org.apache.spark.sql.hive.HivemallOps._ +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.util.Utils + +/** + * Base class for tests with SparkSQL VectorUDT data. + */ +abstract class VectorQueryTest extends QueryTest with SQLTestUtils with TestHiveSingleton { + + private var trainDir: File = _ + private var testDir: File = _ + + // A `libsvm` schema is (Double, ml.linalg.Vector) + protected var mllibTrainDf: DataFrame = _ + protected var mllibTestDf: DataFrame = _ + + override def beforeAll(): Unit = { + super.beforeAll() + val trainLines = + """ + |1 1:1.0 3:2.0 5:3.0 + |0 2:4.0 4:5.0 6:6.0 + |1 1:1.1 4:1.0 5:2.3 7:1.0 + |1 1:1.0 4:1.5 5:2.1 7:1.2 + """.stripMargin + trainDir = Utils.createTempDir() + Files.write(trainLines, new File(trainDir, "train-00000"), StandardCharsets.UTF_8) + val testLines = + """ + |1 1:1.3 3:2.1 5:2.8 + |0 2:3.9 4:5.3 6:8.0 + """.stripMargin + testDir = Utils.createTempDir() + Files.write(testLines, new File(testDir, "test-00000"), StandardCharsets.UTF_8) + + mllibTrainDf = spark.read.format("libsvm").load(trainDir.getAbsolutePath) + // Must be cached because rowid() is deterministic + mllibTestDf = spark.read.format("libsvm").load(testDir.getAbsolutePath) + .withColumn("rowid", rowid()).cache + } + + override def afterAll(): Unit = { + try { + Utils.deleteRecursively(trainDir) + Utils.deleteRecursively(testDir) + } finally { + super.afterAll() + } + } + + protected def withTempModelDir(f: String => Unit): Unit = { + var tempDir: File = null + try { + tempDir = Utils.createTempDir() + f(tempDir.getAbsolutePath + "/xgboost_models") + } catch { + case e: Throwable => fail(s"Unexpected exception detected: ${e}") + } finally { + Utils.deleteRecursively(tempDir) + } + } +} diff --git a/spark/spark-2.3/src/test/scala/org/apache/spark/streaming/HivemallOpsWithFeatureSuite.scala b/spark/spark-2.3/src/test/scala/org/apache/spark/streaming/HivemallOpsWithFeatureSuite.scala new file mode 100644 index 000000000..0e1372dcc --- /dev/null +++ b/spark/spark-2.3/src/test/scala/org/apache/spark/streaming/HivemallOpsWithFeatureSuite.scala @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.streaming + +import scala.reflect.ClassTag + +import org.apache.spark.ml.feature.HivemallLabeledPoint +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.hive.HivemallOps._ +import org.apache.spark.sql.hive.test.HivemallFeatureQueryTest +import org.apache.spark.streaming.HivemallStreamingOps._ +import org.apache.spark.streaming.dstream.InputDStream +import org.apache.spark.streaming.scheduler.StreamInputInfo + +/** + * This is an input stream just for tests. + */ +private[this] class TestInputStream[T: ClassTag]( + ssc: StreamingContext, + input: Seq[Seq[T]], + numPartitions: Int) extends InputDStream[T](ssc) { + + override def start() {} + + override def stop() {} + + override def compute(validTime: Time): Option[RDD[T]] = { + logInfo("Computing RDD for time " + validTime) + val index = ((validTime - zeroTime) / slideDuration - 1).toInt + val selectedInput = if (index < input.size) input(index) else Seq[T]() + + // lets us test cases where RDDs are not created + if (selectedInput == null) { + return None + } + + // Report the input data's information to InputInfoTracker for testing + val inputInfo = StreamInputInfo(id, selectedInput.length.toLong) + ssc.scheduler.inputInfoTracker.reportInfo(validTime, inputInfo) + + val rdd = ssc.sc.makeRDD(selectedInput, numPartitions) + logInfo("Created RDD " + rdd.id + " with " + selectedInput) + Some(rdd) + } +} + +final class HivemallOpsWithFeatureSuite extends HivemallFeatureQueryTest { + + // This implicit value used in `HivemallStreamingOps` + implicit val sqlCtx = hiveContext + + /** + * Run a block of code with the given StreamingContext. + * This method do not stop a given SparkContext because other tests share the context. + */ + private def withStreamingContext[R](ssc: StreamingContext)(block: StreamingContext => R): Unit = { + try { + block(ssc) + ssc.start() + ssc.awaitTerminationOrTimeout(10 * 1000) // 10s wait + } finally { + try { + ssc.stop(stopSparkContext = false) + } catch { + case e: Exception => logError("Error stopping StreamingContext", e) + } + } + } + + // scalastyle:off line.size.limit + + /** + * This test below fails sometimes (too flaky), so we temporarily ignore it. 
+ * The stacktrace of this failure is: + * + * HivemallOpsWithFeatureSuite: + * Exception in thread "broadcast-exchange-60" java.lang.OutOfMemoryError: Java heap space + * at java.nio.HeapByteBuffer.(HeapByteBuffer.java:57) + * at java.nio.ByteBuffer.allocate(ByteBuffer.java:331) + * at org.apache.spark.broadcast.TorrentBroadcast$$anonfun$4.apply(TorrentBroadcast.scala:231) + * at org.apache.spark.broadcast.TorrentBroadcast$$anonfun$4.apply(TorrentBroadcast.scala:231) + * at org.apache.spark.util.io.ChunkedByteBufferOutputStream.allocateNewChunkIfNeeded(ChunkedByteBufferOutputStream.scala:78) + * at org.apache.spark.util.io.ChunkedByteBufferOutputStream.write(ChunkedByteBufferOutputStream.scala:65) + * at net.jpountz.lz4.LZ4BlockOutputStream.flushBufferedData(LZ4BlockOutputStream.java:205) + * at net.jpountz.lz4.LZ4BlockOutputStream.finish(LZ4BlockOutputStream.java:235) + * at net.jpountz.lz4.LZ4BlockOutputStream.close(LZ4BlockOutputStream.java:175) + * at java.io.ObjectOutputStream$BlockDataOutputStream.close(ObjectOutputStream.java:1827) + * at java.io.ObjectOutputStream.close(ObjectOutputStream.java:741) + * at org.apache.spark.serializer.JavaSerializationStream.close(JavaSerializer.scala:57) + * at org.apache.spark.broadcast.TorrentBroadcast$$anonfun$blockifyObject$1.apply$mcV$sp(TorrentBroadcast.scala:238) + * at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1296) + * at org.apache.spark.broadcast.TorrentBroadcast$.blockifyObject(TorrentBroadcast.scala:237) + * at org.apache.spark.broadcast.TorrentBroadcast.writeBlocks(TorrentBroadcast.scala:107) + * at org.apache.spark.broadcast.TorrentBroadcast.(TorrentBroadcast.scala:86) + * at org.apache.spark.broadcast.TorrentBroadcastFactory.newBroadcast(TorrentBroadcastFactory.scala:34) + * ... + */ + + // scalastyle:on line.size.limit + + ignore("streaming") { + import sqlCtx.implicits._ + + // We assume we build a model in advance + val testModel = Seq( + ("0", 0.3f), ("1", 0.1f), ("2", 0.6f), ("3", 0.2f) + ).toDF("feature", "weight") + + withStreamingContext(new StreamingContext(sqlCtx.sparkContext, Milliseconds(100))) { ssc => + val inputData = Seq( + Seq(HivemallLabeledPoint(features = "1:0.6" :: "2:0.1" :: Nil)), + Seq(HivemallLabeledPoint(features = "2:0.9" :: Nil)), + Seq(HivemallLabeledPoint(features = "1:0.2" :: Nil)), + Seq(HivemallLabeledPoint(features = "2:0.1" :: Nil)), + Seq(HivemallLabeledPoint(features = "0:0.6" :: "2:0.4" :: Nil)) + ) + + val inputStream = new TestInputStream[HivemallLabeledPoint](ssc, inputData, 1) + + // Apply predictions on input streams + val prediction = inputStream.predict { streamDf => + val df = streamDf.select(rowid(), $"features").explode_array($"features") + val testDf = df.select( + // TODO: `$"feature"` throws AnalysisException, why? 
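+            // Note: `extract_feature`/`extract_weight` split a "<feature>:<weight>" string,
+            // e.g. "2:0.9" is expected to yield the feature "2" and the weight 0.9.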
+ $"rowid", extract_feature(df("feature")), extract_weight(df("feature")) + ) + testDf.join(testModel, testDf("feature") === testModel("feature"), "LEFT_OUTER") + .select($"rowid", ($"weight" * $"value").as("value")) + .groupBy("rowid").sum("value") + .toDF("rowid", "value") + .select($"rowid", sigmoid($"value")) + } + + // Dummy output stream + prediction.foreachRDD(_ => {}) + } + } +} diff --git a/spark/spark-2.3/src/test/scala/org/apache/spark/test/TestUtils.scala b/spark/spark-2.3/src/test/scala/org/apache/spark/test/TestUtils.scala new file mode 100644 index 000000000..fa7b6e54e --- /dev/null +++ b/spark/spark-2.3/src/test/scala/org/apache/spark/test/TestUtils.scala @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.test + +import scala.reflect.runtime.{universe => ru} + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.DataFrame + +object TestUtils extends Logging { + + // Do benchmark if INFO-log enabled + def benchmark(benchName: String)(testFunc: => Unit): Unit = { + if (log.isDebugEnabled) { + testFunc + } + } + + def expectResult(res: Boolean, errMsg: String): Unit = if (res) { + logWarning(errMsg) + } + + def invokeFunc(cls: Any, func: String, args: Any*): DataFrame = try { + // Invoke a function with the given name via reflection + val im = scala.reflect.runtime.currentMirror.reflect(cls) + val mSym = im.symbol.typeSignature.member(ru.newTermName(func)).asMethod + im.reflectMethod(mSym).apply(args: _*) + .asInstanceOf[DataFrame] + } catch { + case e: Exception => + assert(false, s"Invoking ${func} failed because: ${e.getMessage}") + null // Not executed + } +} + +// TODO: Any same function in o.a.spark.*? 
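+// A small usage sketch (illustrative only): with the implicit conversion below in scope,
+//   1.0 / 3 ~== 0.333   // true, |0.3333... - 0.333| < 0.001
+//   1.0 / 3 ~== 0.3     // false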
+class TestFPWrapper(d: Double) { + + // Check an equality between Double/Float values + def ~==(d: Double): Boolean = Math.abs(this.d - d) < 0.001 +} + +object TestFPWrapper { + + @inline implicit def toTestFPWrapper(d: Double): TestFPWrapper = { + new TestFPWrapper(d) + } +} From 1a7897d3462e10b38a618e65dee756a6d9b318c7 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Fri, 16 Mar 2018 19:58:58 +0900 Subject: [PATCH 2/3] Fix style errors --- bin/run_travis_tests.sh | 6 ++---- .../scala/org/apache/spark/sql/hive/ModelMixingSuite.scala | 7 ++++--- .../apache/spark/sql/hive/benchmark/MiscBenchmark.scala | 4 ++-- .../scala/org/apache/spark/sql/test/VectorQueryTest.scala | 3 ++- 4 files changed, 10 insertions(+), 10 deletions(-) diff --git a/bin/run_travis_tests.sh b/bin/run_travis_tests.sh index d0ad8c206..9693520ab 100755 --- a/bin/run_travis_tests.sh +++ b/bin/run_travis_tests.sh @@ -35,14 +35,12 @@ cd $HIVEMALL_HOME/spark export MAVEN_OPTS="-XX:MaxPermSize=256m" -mvn -q scalastyle:check -Pspark-2.0 -pl spark-2.0 -am test -Dtest=none - -mvn -q scalastyle:check clean -Pspark-2.1 -pl spark-2.1 -am test -Dtest=none +mvn -q scalastyle:check -pl spark-2.0,spark-2.1 -am test # spark-2.2 runs on Java 8+ if [[ ! -z "$(java -version 2>&1 | grep 1.8)" ]]; then mvn -q scalastyle:check clean -Djava.source.version=1.8 -Djava.target.version=1.8 \ - -Pspark-2.2 -pl spark-2.2 -am test -Dtest=none + -pl spark-2.2,spark-2.3 -am test fi exit 0 diff --git a/spark/spark-2.3/src/test/scala/org/apache/spark/sql/hive/ModelMixingSuite.scala b/spark/spark-2.3/src/test/scala/org/apache/spark/sql/hive/ModelMixingSuite.scala index 267179882..ad23e8f32 100644 --- a/spark/spark-2.3/src/test/scala/org/apache/spark/sql/hive/ModelMixingSuite.scala +++ b/spark/spark-2.3/src/test/scala/org/apache/spark/sql/hive/ModelMixingSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.hive -import java.io.{BufferedInputStream, InputStream, InputStreamReader, BufferedReader} +import java.io.{BufferedInputStream, BufferedReader, InputStream, InputStreamReader} import java.net.URL import java.util.UUID import java.util.concurrent.{Executors, ExecutorService} @@ -29,10 +29,11 @@ import hivemall.utils.lang.CommandLineUtils import hivemall.utils.net.NetUtils import org.apache.commons.cli.Options import org.apache.commons.compress.compressors.CompressorStreamFactory -import org.apache.spark.SparkFunSuite import org.scalatest.BeforeAndAfter + +import org.apache.spark.SparkFunSuite import org.apache.spark.ml.feature.HivemallLabeledPoint -import org.apache.spark.sql.{DataFrame, Row, Column} +import org.apache.spark.sql.{Column, DataFrame, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.hive.HivemallGroupedDataset._ import org.apache.spark.sql.hive.HivemallOps._ diff --git a/spark/spark-2.3/src/test/scala/org/apache/spark/sql/hive/benchmark/MiscBenchmark.scala b/spark/spark-2.3/src/test/scala/org/apache/spark/sql/hive/benchmark/MiscBenchmark.scala index 5944ad44f..0a9e4a65d 100644 --- a/spark/spark-2.3/src/test/scala/org/apache/spark/sql/hive/benchmark/MiscBenchmark.scala +++ b/spark/spark-2.3/src/test/scala/org/apache/spark/sql/hive/benchmark/MiscBenchmark.scala @@ -19,10 +19,10 @@ package org.apache.spark.sql.hive.benchmark -import org.apache.spark.sql.{DataFrame, Dataset, Row, Column} +import org.apache.spark.sql.{Column, DataFrame, Dataset, Row} import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.encoders.RowEncoder -import 
org.apache.spark.sql.catalyst.expressions.{Literal, Expression} +import org.apache.spark.sql.catalyst.expressions.{Expression, Literal} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.benchmark.BenchmarkBaseAccessor import org.apache.spark.sql.expressions.Window diff --git a/spark/spark-2.3/src/test/scala/org/apache/spark/sql/test/VectorQueryTest.scala b/spark/spark-2.3/src/test/scala/org/apache/spark/sql/test/VectorQueryTest.scala index f8d5f7246..4e2a0c188 100644 --- a/spark/spark-2.3/src/test/scala/org/apache/spark/sql/test/VectorQueryTest.scala +++ b/spark/spark-2.3/src/test/scala/org/apache/spark/sql/test/VectorQueryTest.scala @@ -23,7 +23,8 @@ import java.io.File import java.nio.charset.StandardCharsets import com.google.common.io.Files -import org.apache.spark.sql.{QueryTest, DataFrame} + +import org.apache.spark.sql.{DataFrame, QueryTest} import org.apache.spark.sql.hive.HivemallOps._ import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.util.Utils From 6000c44ea55f206a249d12b04ee3c303f93b1624 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Mon, 19 Mar 2018 12:57:01 +0900 Subject: [PATCH 3/3] Fix bugs in Generate --- spark/pom.xml | 5 ++++- .../org/apache/spark/sql/hive/HivemallOps.scala | 12 +++++------- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/spark/pom.xml b/spark/pom.xml index f0827ea33..27bb6dbe8 100644 --- a/spark/pom.xml +++ b/spark/pom.xml @@ -158,7 +158,10 @@ org.apache.hivemall:hivemall-spark-common org.apache.hivemall:hivemall-core - io.netty:netty-all + com.github.haifengl:smile-core com.github.haifengl:smile-math com.github.haifengl:smile-data diff --git a/spark/spark-2.3/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala b/spark/spark-2.3/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala index 8323d2286..94bcfd62b 100644 --- a/spark/spark-2.3/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala +++ b/spark/spark-2.3/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.Inner -import org.apache.spark.sql.catalyst.plans.logical.{Generate, JoinTopK, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.{AnalysisBarrier, Generate, JoinTopK, LogicalPlan} import org.apache.spark.sql.execution.UserProvidedPlanner import org.apache.spark.sql.execution.datasources.csv.{CsvToStruct, StructToCsv} import org.apache.spark.sql.functions._ @@ -991,16 +991,14 @@ final class HivemallOps(df: DataFrame) extends Logging { k = kInt, scoreExpr = scoreExpr, groupExprs = groupExprs, - elementSchema = StructType( - rankField +: inputAttrs.map(d => StructField(d.name, d.dataType)) - ), + elementSchema = StructType(rankField :: Nil), children = inputAttrs ), - unrequiredChildIndex = Seq.empty, + unrequiredChildIndex = Nil, outer = false, qualifier = None, - generatorOutput = Seq(rankField.name).map(UnresolvedAttribute(_)) ++ inputAttrs, - child = analyzedPlan + generatorOutput = Nil, + child = AnalysisBarrier(analyzedPlan) ) }