Merge pull request #339 from maropu/InputVectorSupport
Support SparkSQL user-defined vector types in Hivemall functions
myui committed Sep 11, 2016
2 parents 27a5f2d + dec145f commit c7acee98da3c17613955c1afb9e1bafcb8d1fe26
Showing 20 changed files with 823 additions and 516 deletions.
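As a hedged sketch of what the title's vector-type support enables (train_logregr, the implicit DataFrame conversion, and the implicits import are assumptions based on the Hivemall Spark bindings, not shown in this diff), a Spark ML vector column can be handed to a Hivemall function without first flattening it to "index:weight" strings:

    // Hypothetical usage sketch; train_logregr and the implicit DataFrame
    // conversion are assumed, not taken from this diff.
    // Assumes spark.implicits._ is in scope for toDF.
    import org.apache.spark.ml.linalg.Vectors
    import org.apache.spark.sql.hive.HivemallOps._

    val df = Seq(
      (0.0f, Vectors.dense(0.5, 0.3)),
      (1.0f, Vectors.sparse(2, Seq((1, 0.9))))
    ).toDF("label", "features")

    // A VectorUDT column is passed directly to the trainer.
    val model = df.train_logregr(df("features"), df("label"))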
@@ -20,7 +20,7 @@ package org.apache.spark.sql.hive
 import java.util.UUID
 
 import org.apache.spark.Logging
-import org.apache.spark.ml.feature.HmFeature
+import org.apache.spark.ml.feature.HivemallFeature
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
 import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, NamedExpression}
@@ -628,7 +628,7 @@ final class HivemallOps(df: DataFrame) extends Logging {
   def explode_array(expr: Column): DataFrame = {
     df.explode(expr) { case Row(v: Seq[_]) =>
       // Type erasure removes the component type in Seq
-      v.map(s => HmFeature(s.asInstanceOf[String]))
+      v.map(s => HivemallFeature(s.asInstanceOf[String]))
     }
   }

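For reference, a minimal usage sketch of the method above (assuming the HivemallOps implicits are in scope and df has an Array[String] column named features; this call is not part of the diff):

    // Each "index:weight" string in the array becomes its own HivemallFeature row.
    import org.apache.spark.sql.hive.HivemallOps._
    val exploded = df.explode_array(df("features"))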
@@ -22,12 +22,12 @@ import org.apache.spark.SparkFunSuite
 class HivemallLabeledPointSuite extends SparkFunSuite {
 
   test("toString") {
-    val lp = HmLabeledPoint(1.0f, Seq("1:0.5", "3:0.3", "8:0.1"))
+    val lp = HivemallLabeledPoint(1.0f, Seq("1:0.5", "3:0.3", "8:0.1"))
     assert(lp.toString === "1.0,[1:0.5,3:0.3,8:0.1]")
   }
 
   test("parse") {
-    val lp = HmLabeledPoint.parse("1.0,[1:0.5,3:0.3,8:0.1]")
+    val lp = HivemallLabeledPoint.parse("1.0,[1:0.5,3:0.3,8:0.1]")
     assert(lp.label === 1.0)
     assert(lp.features === Seq("1:0.5", "3:0.3", "8:0.1"))
   }
@@ -25,7 +25,7 @@ import java.util.concurrent.{ExecutorService, Executors}
 import org.apache.commons.cli.Options
 import org.apache.commons.compress.compressors.CompressorStreamFactory
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.ml.feature.HmLabeledPoint
+import org.apache.spark.ml.feature.HivemallLabeledPoint
 import org.apache.spark.sql.functions.when
 import org.apache.spark.sql.hive.HivemallOps._
 import org.apache.spark.sql.hive.HivemallUtils._
@@ -46,7 +46,7 @@ final class ModelMixingSuite extends SparkFunSuite with BeforeAndAfter {
   val a9aLineParser = (line: String) => {
     val elements = line.split(" ")
     val (label, features) = (elements.head, elements.tail)
-    HmLabeledPoint(if (label == "+1") 1.0f else 0.0f, features)
+    HivemallLabeledPoint(if (label == "+1") 1.0f else 0.0f, features)
   }
 
   lazy val trainA9aData: DataFrame =
@@ -63,7 +63,7 @@ final class ModelMixingSuite extends SparkFunSuite with BeforeAndAfter {
   val kdd2010aLineParser = (line: String) => {
     val elements = line.split(" ")
     val (label, features) = (elements.head, elements.tail)
-    HmLabeledPoint(if (label == "1") 1.0f else 0.0f, features)
+    HivemallLabeledPoint(if (label == "1") 1.0f else 0.0f, features)
   }
 
   lazy val trainKdd2010aData: DataFrame =
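Both parsers read LIBSVM-formatted lines; for instance:

    // a9aLineParser("+1 1:0.5 3:0.3")
    //   == HivemallLabeledPoint(1.0f, Seq("1:0.5", "3:0.3"))
    // kdd2010aLineParser("0 2:0.9")
    //   == HivemallLabeledPoint(0.0f, Seq("2:0.9"))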
@@ -93,7 +93,7 @@ final class ModelMixingSuite extends SparkFunSuite with BeforeAndAfter {
   var assignedPort: Int = _
 
   private def getDataFromURI(
-      in: InputStream, lineParseFunc: String => HmLabeledPoint, numPart: Int = 2): DataFrame = {
+      in: InputStream, lineParseFunc: String => HivemallLabeledPoint, numPart: Int = 2): DataFrame = {
     val reader = new BufferedReader(new InputStreamReader(in))
     try {
       // Cache all data because the stream is closed soon
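The hunk is cut off here; a hedged sketch of what plausibly follows (the exact body is not in this diff): read every line eagerly before the stream closes, parse each with lineParseFunc, and parallelize into numPart partitions.

    // Hypothetical continuation (not from this diff):
    //   val lines = Iterator.continually(reader.readLine()).takeWhile(_ != null).toSeq
    //   sqlCtx.createDataFrame(
    //     sqlCtx.sparkContext.parallelize(lines.map(lineParseFunc), numPart))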
@@ -19,7 +19,7 @@ package org.apache.spark.streaming
 
 import reflect.ClassTag
 
-import org.apache.spark.ml.feature.HmLabeledPoint
+import org.apache.spark.ml.feature.HivemallLabeledPoint
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.hive.HivemallOps._
 import org.apache.spark.streaming.HivemallStreamingOps._
@@ -92,14 +92,14 @@ final class HivemallOpsSuite extends HivemallQueryTest {
 
     withStreamingContext(new StreamingContext(sqlCtx.sparkContext, Milliseconds(100))) { ssc =>
       val inputData = Seq(
-        Seq(HmLabeledPoint(features = "1:0.6" :: "2:0.1" :: Nil)),
-        Seq(HmLabeledPoint(features = "2:0.9" :: Nil)),
-        Seq(HmLabeledPoint(features = "1:0.2" :: Nil)),
-        Seq(HmLabeledPoint(features = "2:0.1" :: Nil)),
-        Seq(HmLabeledPoint(features = "0:0.6" :: "2:0.4" :: Nil))
+        Seq(HivemallLabeledPoint(features = "1:0.6" :: "2:0.1" :: Nil)),
+        Seq(HivemallLabeledPoint(features = "2:0.9" :: Nil)),
+        Seq(HivemallLabeledPoint(features = "1:0.2" :: Nil)),
+        Seq(HivemallLabeledPoint(features = "2:0.1" :: Nil)),
+        Seq(HivemallLabeledPoint(features = "0:0.6" :: "2:0.4" :: Nil))
       )
 
-      val inputStream = new TestInputStream[HmLabeledPoint](ssc, inputData, 1)
+      val inputStream = new TestInputStream[HivemallLabeledPoint](ssc, inputData, 1)
 
       // Apply predictions on input streams
       val prediction = inputStream.predict { streamDf =>
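The hunk is truncated at the closure. As a hedged sketch of the pattern (the model schema and column names are assumptions, not from this diff), predict applies the given function to the DataFrame built from each micro-batch and returns the resulting stream:

    // Hypothetical body for the truncated closure above:
    //   val prediction = inputStream.predict { streamDf =>
    //     streamDf.join(modelDf, streamDf("feature") === modelDf("feature"))
    //       .groupBy("rowid").sum("score")  // aggregate a score per input row
    //   }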
@@ -0,0 +1,57 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package hivemall.xgboost

import scala.collection.mutable

import org.apache.commons.cli.Options
import org.apache.spark.annotation.AlphaComponent

/**
* :: AlphaComponent ::
* A utility class to generate a sequence of options used in XGBoost.
*/
@AlphaComponent
case class XGBoostOptions() {
  private val params: mutable.Map[String, String] = mutable.Map.empty
  private val options: Options = {
    new XGBoostUDTF() {
      def options(): Options = super.getOptions()
    }.options()
  }

  private def isValidKey(key: String): Boolean = {
    // TODO: Is there another way to handle all the XGBoost options?
    options.hasOption(key) || key == "num_class"
  }

  def set(key: String, value: String): XGBoostOptions = {
    require(isValidKey(key), s"non-existing key detected in XGBoost options: ${key}")
    params.put(key, value)
    this
  }

  def help(): Unit = {
    import scala.collection.JavaConversions._
    options.getOptions.foreach(println)
  }

  override def toString(): String = {
    params.map { case (key, value) => s"-$key $value" }.mkString(" ")
  }
}
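Given the definitions above, a small illustrative example (whether num_round is registered in the UDTF's Options is an assumption; num_class is special-cased in isValidKey):

    // Chained set() calls build the flag string handed to the XGBoost UDTF:
    val opts = XGBoostOptions()
      .set("num_round", "100")
      .set("num_class", "3")
    println(opts)  // e.g. "-num_round 100 -num_class 3" (mutable.Map ordering aside)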
@@ -0,0 +1,33 @@
/*
* Hivemall: Hive scalable Machine Learning Library
*
* Copyright (C) 2015 Makoto YUI
* Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST)
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package hivemall

import org.apache.spark.sql.hive.source.XGBoostFileFormat

package object xgboost {

  /**
   * Model files for libxgboost are loaded as follows:
   *
   *   import hivemall.xgboost._
   *   val modelDf = sparkSession.read.format(xgboost).load(modelDir.getCanonicalPath)
   */
  val xgboost = classOf[XGBoostFileFormat].getName
}
