
Commit 0134d32

Merge 16b5c6f into 4688494
wangyum committed Jan 25, 2017
2 parents 4688494 + 16b5c6f commit 0134d32
Showing 43 changed files with 31 additions and 157 deletions.
2 changes: 1 addition & 1 deletion .travis.yml
@@ -34,7 +34,7 @@ notifications:
email: false

script:
- mvn -q scalastyle:check test -Pspark-2.0
- mvn -q scalastyle:check test -Pspark-2
# test the spark-1.6 module only in this second run
- mvn -q scalastyle:check clean -Pspark-1.6 -pl spark/spark-1.6 -am test -Dtest=none

8 changes: 4 additions & 4 deletions NOTICE
@@ -51,10 +51,10 @@ o hivemall/core/src/main/java/hivemall/utils/buffer/DynamicByteArray.java
Licensed under the Apache License, Version 2.0

o hivemall/spark/spark-1.6/extra-src/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala
hivemall/spark/spark-2.0/extra-src/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala
hivemall/spark/spark-2.0/src/test/scala/org/apache/spark/sql/QueryTest.scala
hivemall/spark/spark-2.0/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
hivemall/spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/test/TestHiveSingleton.scala
hivemall/spark/spark-2/extra-src/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala
hivemall/spark/spark-2/src/test/scala/org/apache/spark/sql/QueryTest.scala
hivemall/spark/spark-2/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
hivemall/spark/spark-2/src/test/scala/org/apache/spark/sql/hive/test/TestHiveSingleton.scala

Copyright (C) 2014-2016 The Apache Software Foundation.

6 changes: 3 additions & 3 deletions bin/format_header.sh
@@ -35,10 +35,10 @@ HIVEMALL_HOME=`pwd`
mvn license:format

cd $HIVEMALL_HOME/spark/spark-common
mvn license:format -P spark-2.0
mvn license:format -P spark-2

cd $HIVEMALL_HOME/spark/spark-1.6
mvn license:format -P spark-1.6

cd $HIVEMALL_HOME/spark/spark-2.0
mvn license:format -P spark-2.0
cd $HIVEMALL_HOME/spark/spark-2
mvn license:format -P spark-2
8 changes: 4 additions & 4 deletions pom.xml
@@ -235,14 +235,14 @@

<profiles>
<profile>
<id>spark-2.0</id>
<id>spark-2</id>
<modules>
<module>spark/spark-2.0</module>
<module>spark/spark-2</module>
<module>spark/spark-common</module>
</modules>
<properties>
<spark.version>2.0.0</spark.version>
<spark.binary.version>2.0</spark.binary.version>
<spark.version>2.1.0</spark.version>
<spark.binary.version>2.1</spark.binary.version>
</properties>
</profile>
<profile>
File renamed without changes.
File renamed without changes.
2 changes: 1 addition & 1 deletion spark/spark-2.0/pom.xml → spark/spark-2/pom.xml
@@ -28,7 +28,7 @@
</parent>

<artifactId>hivemall-spark</artifactId>
<name>Hivemall on Spark 2.0</name>
<name>Hivemall on Spark 2</name>
<packaging>jar</packaging>

<properties>
@@ -91,16 +91,21 @@ final class XGBoostFileFormat extends FileFormat with DataSourceRegister {
options: Map[String, String],
dataSchema: StructType): OutputWriterFactory = {
new OutputWriterFactory {

override def getFileExtension(context: TaskAttemptContext): String = {
".xgbmodel"
}

override def newInstance(
path: String,
bucketId: Option[Int],
dataSchema: StructType,
context: TaskAttemptContext): OutputWriter = {
if (bucketId.isDefined) {
sys.error("XGBoostFileFormat doesn't support bucketing")
}
new XGBoostOutputWriter(path, dataSchema, context)
}

override def newWriter(path: String): OutputWriter = {
throw new UnsupportedOperationException("")
}
}
}

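A side note on the writer factory above: getFileExtension makes every persisted model file end in ".xgbmodel", which is what lets the test suite further down count saved models by extension. The snippet below is a minimal, hedged sketch of that counting step only (the directory path is hypothetical), not code from this commit.

import java.io.File

object CountXGBoostModels {
  // Count the files the writer above would produce: one "*.xgbmodel" file per
  // partition that trained a model.
  def countModels(dirPath: String): Int =
    Option(new File(dirPath).listFiles())
      .map(_.count(_.getName.endsWith(".xgbmodel")))
      .getOrElse(0)

  def main(args: Array[String]): Unit = {
    println(countModels("/tmp/xgbmodels")) // hypothetical output directory
  }
}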
@@ -18,23 +18,13 @@
*/
package org.apache.spark.sql

import java.util.{ArrayDeque, Locale, TimeZone}
import java.util.{Locale, TimeZone}

import scala.collection.JavaConverters._
import scala.util.control.NonFatal

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.ImperativeAggregate
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.trees.TreeNode
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.execution.LogicalRDD
import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
import org.apache.spark.sql.execution.columnar.InMemoryRelation
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.execution.streaming.MemoryPlan
import org.apache.spark.sql.types.ObjectType

abstract class QueryTest extends PlanTest {

@@ -120,7 +110,6 @@ abstract class QueryTest extends PlanTest {
throw ae
}
}
checkJsonFormat(analyzedDS)
assertEmptyMissingInput(analyzedDS)

try ds.collect() catch {
@@ -168,8 +157,6 @@
}
}

checkJsonFormat(analyzedDF)

assertEmptyMissingInput(analyzedDF)

QueryTest.checkAnswer(analyzedDF, expectedAnswer) match {
@@ -228,123 +215,6 @@
planWithCaching)
}

private def checkJsonFormat(ds: Dataset[_]): Unit = {
// Get the analyzed plan and rewrite the PredicateSubqueries in order to make sure that
// RDD and Data resolution does not break.
val logicalPlan = ds.queryExecution.analyzed

// bypass some cases that we can't handle currently.
logicalPlan.transform {
case _: ObjectConsumer => return
case _: ObjectProducer => return
case _: AppendColumns => return
case _: LogicalRelation => return
case p if p.getClass.getSimpleName == "MetastoreRelation" => return
case _: MemoryPlan => return
}.transformAllExpressions {
case a: ImperativeAggregate => return
case _: TypedAggregateExpression => return
case Literal(_, _: ObjectType) => return
}

// bypass hive tests before we fix all corner cases in hive module.
if (this.getClass.getName.startsWith("org.apache.spark.sql.hive")) return

val jsonString = try {
logicalPlan.toJSON
} catch {
case NonFatal(e) =>
fail(
s"""
|Failed to parse logical plan to JSON:
|${logicalPlan.treeString}
""".stripMargin, e)
}

// scala function is not serializable to JSON, use null to replace them so that we can compare
// the plans later.
val normalized1 = logicalPlan.transformAllExpressions {
case udf: ScalaUDF => udf.copy(function = null)
case gen: UserDefinedGenerator => gen.copy(function = null)
}

// RDDs/data are not serializable to JSON, so we need to collect LogicalPlans that contains
// these non-serializable stuff, and use these original ones to replace the null-placeholders
// in the logical plans parsed from JSON.
val logicalRDDs = new ArrayDeque[LogicalRDD]()
val localRelations = new ArrayDeque[LocalRelation]()
val inMemoryRelations = new ArrayDeque[InMemoryRelation]()
def collectData: (LogicalPlan => Unit) = {
case l: LogicalRDD =>
logicalRDDs.offer(l)
case l: LocalRelation =>
localRelations.offer(l)
case i: InMemoryRelation =>
inMemoryRelations.offer(i)
case p =>
p.expressions.foreach {
_.foreach {
case s: SubqueryExpression =>
s.query.foreach(collectData)
case _ =>
}
}
}
logicalPlan.foreach(collectData)


val jsonBackPlan = try {
TreeNode.fromJSON[LogicalPlan](jsonString, spark.sparkContext)
} catch {
case NonFatal(e) =>
fail(
s"""
|Failed to rebuild the logical plan from JSON:
|${logicalPlan.treeString}
|
|${logicalPlan.prettyJson}
""".stripMargin, e)
}

def renormalize: PartialFunction[LogicalPlan, LogicalPlan] = {
case l: LogicalRDD =>
val origin = logicalRDDs.pop()
LogicalRDD(l.output, origin.rdd)(spark)
case l: LocalRelation =>
val origin = localRelations.pop()
l.copy(data = origin.data)
case l: InMemoryRelation =>
val origin = inMemoryRelations.pop()
InMemoryRelation(
l.output,
l.useCompression,
l.batchSize,
l.storageLevel,
origin.child,
l.tableName)(
origin.cachedColumnBuffers,
origin.batchStats)
case p =>
p.transformExpressions {
case s: SubqueryExpression =>
s.withNewPlan(s.query.transformDown(renormalize))
}
}
val normalized2 = jsonBackPlan.transformDown(renormalize)

assert(logicalRDDs.isEmpty)
assert(localRelations.isEmpty)
assert(inMemoryRelations.isEmpty)

if (normalized1 != normalized2) {
fail(
s"""
|== FAIL: the logical plan parsed from json does not match the original one ===
|${sideBySide(logicalPlan.treeString, normalized2.treeString).mkString("\n")}
""".stripMargin)
}
}

/**
* Asserts that a given [[Dataset]] does not have missing inputs in all the analyzed plans.
*/
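With checkJsonFormat removed, assertEmptyMissingInput (whose scaladoc closes the hunk above) carries the remaining per-plan sanity check in these tests. The following is a hedged, self-contained sketch of what such a check amounts to, written against public Catalyst and Dataset APIs and simplified to the analyzed plan only; it is an illustration under those assumptions, not the project's implementation.

import org.apache.spark.sql.{Dataset, SparkSession}

object MissingInputSketch {
  // Fail if any node in the analyzed plan references attributes that neither its
  // children nor the node itself produce.
  def assertEmptyMissingInput(ds: Dataset[_]): Unit = {
    ds.queryExecution.analyzed.foreach { node =>
      assert(node.missingInput.isEmpty,
        s"${node.nodeName} is missing inputs: ${node.missingInput.mkString(", ")}")
    }
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("missing-input-sketch").getOrCreate()
    import spark.implicits._
    assertEmptyMissingInput(Seq(1, 2, 3).toDS().filter(_ > 1))
    spark.stop()
  }
}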
@@ -21,6 +21,7 @@ package org.apache.spark.sql.hive
import java.io.File

import hivemall.xgboost._
import org.scalatest.Ignore

import org.apache.spark.sql.Row
import org.apache.spark.sql.functions._
Expand All @@ -30,6 +31,7 @@ import org.apache.spark.sql.hive.HivemallUtils._
import org.apache.spark.sql.types._
import org.apache.spark.test.VectorQueryTest

@Ignore
final class XGBoostSuite extends VectorQueryTest {
import hiveContext.implicits._

Expand All @@ -40,7 +42,7 @@ final class XGBoostSuite extends VectorQueryTest {
private val numModles = 3

private def countModels(dirPath: String): Int = {
new File(dirPath).listFiles().toSeq.count(_.getName.startsWith("xgbmodel-"))
new File(dirPath).listFiles().toSeq.count(_.getName.endsWith(".xgbmodel"))
}
test("check XGBoost options") {
assert(s"$defaultOptions" == "-max_depth 4 -num_round 10")
Expand All @@ -57,14 +59,11 @@ final class XGBoostSuite extends VectorQueryTest {
// Save built models in persistent storage
mllibTrainDf.repartition(numModles)
.train_xgboost_regr($"features", $"label", s"${defaultOptions}")
.write.format(xgboost).save(tempDir)

.write.format(xgboost).save(tempDir)
// Check #models generated by XGBoost
assert(countModels(tempDir) == numModles)

// Load the saved models
val model = hiveContext.sparkSession.read.format(xgboost).load(tempDir)
val predict = model.join(mllibTestDf)
val predict = model.crossJoin(mllibTestDf)
.xgboost_predict($"rowid", $"features", $"model_id", $"pred_model")
.groupBy("rowid").avg()
.as("rowid", "predicted")
Expand All @@ -89,7 +88,7 @@ final class XGBoostSuite extends VectorQueryTest {
assert(countModels(tempDir) == numModles)

val model = hiveContext.sparkSession.read.format(xgboost).load(tempDir)
val predict = model.join(mllibTestDf)
val predict = model.crossJoin(mllibTestDf)
.xgboost_predict($"rowid", $"features", $"model_id", $"pred_model")
.groupBy("rowid").avg()
.as("rowid", "predicted")
@@ -117,7 +116,7 @@ final class XGBoostSuite extends VectorQueryTest {
assert(countModels(tempDir) == numModles)

val model = hiveContext.sparkSession.read.format(xgboost).load(tempDir)
val predict = model.join(mllibTestDf)
val predict = model.crossJoin(mllibTestDf)
.xgboost_multiclass_predict($"rowid", $"features", $"model_id", $"pred_model")
.groupBy("rowid").max_label("probability", "label")
.toDF("rowid", "predicted")
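The repeated join-to-crossJoin changes above reflect a Spark 2.x rule: a join with no join condition is rejected as an implicit cartesian product (unless spark.sql.crossJoin.enabled is set), while crossJoin requests the cartesian product explicitly. Below is a minimal, hedged sketch with toy data; the column names only mimic the suite above, and nothing here comes from this repository.

import org.apache.spark.sql.SparkSession

object CrossJoinSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("crossjoin-sketch")
      .getOrCreate()
    import spark.implicits._

    // Toy stand-ins for the loaded models and the test rows.
    val models = Seq(("model-1", "state-1"), ("model-2", "state-2")).toDF("model_id", "pred_model")
    val rows   = Seq((1, 0.5), (2, 1.5)).toDF("rowid", "feature")

    // models.join(rows) has no condition, so Spark flags it as an implicit
    // cartesian product; crossJoin states the intent and succeeds.
    val scored = models.crossJoin(rows)
    scored.show()

    spark.stop()
  }
}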
@@ -109,7 +109,7 @@ class MiscBenchmark extends SparkFunSuite {

private def addBenchmarkCase(name: String, df: DataFrame)(implicit benchmark: Benchmark): Unit = {
benchmark.addCase(name, numIters) { _ =>
df.queryExecution.executedPlan(0).execute().foreach(x => Unit)
df.queryExecution.executedPlan.execute().foreach(x => Unit)
}
}

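For context on the helper above: executedPlan.execute() returns the physical plan's RDD of internal rows, and iterating it with a no-op action forces the work to run without collecting results to the driver, which is what a micro-benchmark wants to time. A hedged, self-contained sketch of the same idea on a toy query (not this repository's benchmark harness):

import org.apache.spark.sql.SparkSession

object ForcePlanExecution {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("force-plan").getOrCreate()

    val df = spark.range(1000000L).toDF("id").selectExpr("id * 2 AS twice")

    // Run the physical plan and discard every row, so only execution time matters
    // (no rows are brought back to the driver).
    df.queryExecution.executedPlan.execute().foreach(_ => ())

    spark.stop()
  }
}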
