From fc3a6fe67f5aeda2443958c31f097daeba8549e5 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 1 Jul 2015 00:08:16 -0700 Subject: [PATCH 01/18] [SPARK-8749][SQL] Remove HiveTypeCoercion trait. Moved all the rules into the companion object. Author: Reynold Xin Closes #7147 from rxin/SPARK-8749 and squashes the following commits: c1c6dc0 [Reynold Xin] [SPARK-8749][SQL] Remove HiveTypeCoercion trait. --- .../sql/catalyst/analysis/Analyzer.scala | 4 +- .../catalyst/analysis/HiveTypeCoercion.scala | 59 ++++++++----------- .../analysis/HiveTypeCoercionSuite.scala | 14 ++--- 3 files changed, 33 insertions(+), 44 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 117c87a785fdb..15e84e68b9881 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -43,7 +43,7 @@ class Analyzer( registry: FunctionRegistry, conf: CatalystConf, maxIterations: Int = 100) - extends RuleExecutor[LogicalPlan] with HiveTypeCoercion with CheckAnalysis { + extends RuleExecutor[LogicalPlan] with CheckAnalysis { def resolver: Resolver = { if (conf.caseSensitiveAnalysis) { @@ -76,7 +76,7 @@ class Analyzer( ExtractWindowExpressions :: GlobalAggregates :: UnresolvedHavingClauseAttributes :: - typeCoercionRules ++ + HiveTypeCoercion.typeCoercionRules ++ extendedResolutionRules : _*) ) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index e525ad623ff12..a9d396d1faeeb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -22,7 +22,32 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Union} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.types._ + +/** + * A collection of [[Rule Rules]] that can be used to coerce differing types that + * participate in operations into compatible ones. Most of these rules are based on Hive semantics, + * but they do not introduce any dependencies on the hive codebase. For this reason they remain in + * Catalyst until we have a more standard set of coercions. + */ object HiveTypeCoercion { + + val typeCoercionRules = + PropagateTypes :: + ConvertNaNs :: + InConversion :: + WidenTypes :: + PromoteStrings :: + DecimalPrecision :: + BooleanEquality :: + StringToIntegralCasts :: + FunctionArgumentConversion :: + CaseWhenCoercion :: + IfCoercion :: + Division :: + PropagateTypes :: + AddCastForAutoCastInputTypes :: + Nil + // See https://cwiki.apache.org/confluence/display/Hive/LanguageManual+Types. // The conversion for integral and floating point types have a linear widening hierarchy: private val numericPrecedence = @@ -79,7 +104,6 @@ object HiveTypeCoercion { }) } - /** * Find the tightest common type of a set of types by continuously applying * `findTightestCommonTypeOfTwo` on these types. @@ -90,34 +114,6 @@ object HiveTypeCoercion { case Some(d) => findTightestCommonTypeOfTwo(d, c) }) } -} - -/** - * A collection of [[Rule Rules]] that can be used to coerce differing types that - * participate in operations into compatible ones. Most of these rules are based on Hive semantics, - * but they do not introduce any dependencies on the hive codebase. For this reason they remain in - * Catalyst until we have a more standard set of coercions. - */ -trait HiveTypeCoercion { - - import HiveTypeCoercion._ - - val typeCoercionRules = - PropagateTypes :: - ConvertNaNs :: - InConversion :: - WidenTypes :: - PromoteStrings :: - DecimalPrecision :: - BooleanEquality :: - StringToIntegralCasts :: - FunctionArgumentConversion :: - CaseWhenCoercion :: - IfCoercion :: - Division :: - PropagateTypes :: - AddCastForAutoCastInputTypes :: - Nil /** * Applies any changes to [[AttributeReference]] data types that are made by other rules to @@ -202,8 +198,6 @@ trait HiveTypeCoercion { * - LongType to DoubleType */ object WidenTypes extends Rule[LogicalPlan] { - import HiveTypeCoercion._ - def apply(plan: LogicalPlan): LogicalPlan = plan transform { // TODO: unions with fixed-precision decimals case u @ Union(left, right) if u.childrenResolved && !u.resolved => @@ -655,8 +649,6 @@ trait HiveTypeCoercion { * Coerces the type of different branches of a CASE WHEN statement to a common type. */ object CaseWhenCoercion extends Rule[LogicalPlan] { - import HiveTypeCoercion._ - def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { case c: CaseWhenLike if c.childrenResolved && !c.valueTypesEqual => logDebug(s"Input values for null casting ${c.valueTypes.mkString(",")}") @@ -714,7 +706,6 @@ trait HiveTypeCoercion { * [[AutoCastInputTypes]]. */ object AddCastForAutoCastInputTypes extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala index f7b8e21bed490..eae3666595a38 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala @@ -113,8 +113,7 @@ class HiveTypeCoercionSuite extends PlanTest { } test("coalesce casts") { - val fac = new HiveTypeCoercion { }.FunctionArgumentConversion - ruleTest(fac, + ruleTest(HiveTypeCoercion.FunctionArgumentConversion, Coalesce(Literal(1.0) :: Literal(1) :: Literal.create(1.0, FloatType) @@ -123,7 +122,7 @@ class HiveTypeCoercionSuite extends PlanTest { :: Cast(Literal(1), DoubleType) :: Cast(Literal.create(1.0, FloatType), DoubleType) :: Nil)) - ruleTest(fac, + ruleTest(HiveTypeCoercion.FunctionArgumentConversion, Coalesce(Literal(1L) :: Literal(1) :: Literal(new java.math.BigDecimal("1000000000000000000000")) @@ -135,7 +134,7 @@ class HiveTypeCoercionSuite extends PlanTest { } test("type coercion for If") { - val rule = new HiveTypeCoercion { }.IfCoercion + val rule = HiveTypeCoercion.IfCoercion ruleTest(rule, If(Literal(true), Literal(1), Literal(1L)), If(Literal(true), Cast(Literal(1), LongType), Literal(1L)) @@ -148,19 +147,18 @@ class HiveTypeCoercionSuite extends PlanTest { } test("type coercion for CaseKeyWhen") { - val cwc = new HiveTypeCoercion {}.CaseWhenCoercion - ruleTest(cwc, + ruleTest(HiveTypeCoercion.CaseWhenCoercion, CaseKeyWhen(Literal(1.toShort), Seq(Literal(1), Literal("a"))), CaseKeyWhen(Cast(Literal(1.toShort), IntegerType), Seq(Literal(1), Literal("a"))) ) - ruleTest(cwc, + ruleTest(HiveTypeCoercion.CaseWhenCoercion, CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a"))), CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a"))) ) } test("type coercion simplification for equal to") { - val be = new HiveTypeCoercion {}.BooleanEquality + val be = HiveTypeCoercion.BooleanEquality ruleTest(be, EqualTo(Literal(true), Literal(1)), From 0eee0615894cda8ae1b2c8e61b8bda0ff648a219 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 1 Jul 2015 01:02:33 -0700 Subject: [PATCH 02/18] [SQL] [MINOR] remove internalRowRDD in DataFrame Developers have already familiar with `queryExecution.toRDD` as internal row RDD, and we should not add new concept. Author: Wenchen Fan Closes #7116 from cloud-fan/internal-rdd and squashes the following commits: 24756ca [Wenchen Fan] remove internalRowRDD --- sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala | 4 +--- .../org/apache/spark/sql/execution/stat/FrequentItems.scala | 2 +- .../org/apache/spark/sql/execution/stat/StatFunctions.scala | 2 +- .../main/scala/org/apache/spark/sql/sources/commands.scala | 4 ++-- 4 files changed, 5 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 8fe1f7e34cb5e..caad2da80b1eb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -1469,14 +1469,12 @@ class DataFrame private[sql]( lazy val rdd: RDD[Row] = { // use a local variable to make sure the map closure doesn't capture the whole DataFrame val schema = this.schema - internalRowRdd.mapPartitions { rows => + queryExecution.toRdd.mapPartitions { rows => val converter = CatalystTypeConverters.createToScalaConverter(schema) rows.map(converter(_).asInstanceOf[Row]) } } - private[sql] def internalRowRdd = queryExecution.executedPlan.execute() - /** * Returns the content of the [[DataFrame]] as a [[JavaRDD]] of [[Row]]s. * @group rdd diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala index 3ebbf96090a55..4e2e2c210d5a9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala @@ -90,7 +90,7 @@ private[sql] object FrequentItems extends Logging { (name, originalSchema.fields(index).dataType) } - val freqItems = df.select(cols.map(Column(_)) : _*).internalRowRdd.aggregate(countMaps)( + val freqItems = df.select(cols.map(Column(_)) : _*).queryExecution.toRdd.aggregate(countMaps)( seqOp = (counts, row) => { var i = 0 while (i < numCols) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala index b624ef7e8fa1a..23ddfa9839e5e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala @@ -82,7 +82,7 @@ private[sql] object StatFunctions extends Logging { s"with dataType ${data.get.dataType} not supported.") } val columns = cols.map(n => Column(Cast(Column(n).expr, DoubleType))) - df.select(columns: _*).internalRowRdd.aggregate(new CovarianceCounter)( + df.select(columns: _*).queryExecution.toRdd.aggregate(new CovarianceCounter)( seqOp = (counter, row) => { counter.add(row.getDouble(0), row.getDouble(1)) }, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala index 42b51caab5ce9..7214eb0b4169a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala @@ -154,7 +154,7 @@ private[sql] case class InsertIntoHadoopFsRelation( writerContainer.driverSideSetup() try { - df.sqlContext.sparkContext.runJob(df.internalRowRdd, writeRows _) + df.sqlContext.sparkContext.runJob(df.queryExecution.toRdd, writeRows _) writerContainer.commitJob() relation.refresh() } catch { case cause: Throwable => @@ -220,7 +220,7 @@ private[sql] case class InsertIntoHadoopFsRelation( writerContainer.driverSideSetup() try { - df.sqlContext.sparkContext.runJob(df.internalRowRdd, writeRows _) + df.sqlContext.sparkContext.runJob(df.queryExecution.toRdd, writeRows _) writerContainer.commitJob() relation.refresh() } catch { case cause: Throwable => From 97652416e22ae7d4c471178377a7dda61afb1f7a Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 1 Jul 2015 01:08:20 -0700 Subject: [PATCH 03/18] [SPARK-8750][SQL] Remove the closure in functions.callUdf. Author: Reynold Xin Closes #7148 from rxin/calludf-closure and squashes the following commits: 00df372 [Reynold Xin] Fixed index out of bound exception. 4beba76 [Reynold Xin] [SPARK-8750][SQL] Remove the closure in functions.callUdf. --- .../main/scala/org/apache/spark/sql/functions.scala | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 5767668dd339b..4e8f3f96bf4db 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1829,7 +1829,15 @@ object functions { */ @deprecated("Use callUDF", "1.5.0") def callUdf(udfName: String, cols: Column*): Column = { - UnresolvedFunction(udfName, cols.map(_.expr)) + // Note: we avoid using closures here because on file systems that are case-insensitive, the + // compiled class file for the closure here will conflict with the one in callUDF (upper case). + val exprs = new Array[Expression](cols.size) + var i = 0 + while (i < cols.size) { + exprs(i) = cols(i).expr + i += 1 + } + UnresolvedFunction(udfName, exprs) } } From fdcad6ef48a9e790776c316124bd6478ab6bd5c8 Mon Sep 17 00:00:00 2001 From: cocoatomo Date: Wed, 1 Jul 2015 09:37:09 -0700 Subject: [PATCH 04/18] [SPARK-8763] [PYSPARK] executing run-tests.py with Python 2.6 fails with absence of subprocess.check_output function Running run-tests.py with Python 2.6 cause following error: ``` Running PySpark tests. Output is in python//Users/tomohiko/.jenkins/jobs/pyspark_test/workspace/python/unit-tests.log Will test against the following Python executables: ['python2.6', 'python3.4', 'pypy'] Will test the following Python modules: ['pyspark-core', 'pyspark-ml', 'pyspark-mllib', 'pyspark-sql', 'pyspark-streaming'] Traceback (most recent call last): File "./python/run-tests.py", line 196, in main() File "./python/run-tests.py", line 159, in main python_implementation = subprocess.check_output( AttributeError: 'module' object has no attribute 'check_output' ... ``` The cause of this error is using subprocess.check_output function, which exists since Python 2.7. (ref. https://docs.python.org/2.7/library/subprocess.html#subprocess.check_output) Author: cocoatomo Closes #7161 from cocoatomo/issues/8763-test-fails-py26 and squashes the following commits: cf4f901 [cocoatomo] [SPARK-8763] backport process.check_output function from Python 2.7 --- python/run-tests.py | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/python/run-tests.py b/python/run-tests.py index b7737650daa54..7638854def2e8 100755 --- a/python/run-tests.py +++ b/python/run-tests.py @@ -31,6 +31,23 @@ import Queue else: import queue as Queue +if sys.version_info >= (2, 7): + subprocess_check_output = subprocess.check_output +else: + # SPARK-8763 + # backported from subprocess module in Python 2.7 + def subprocess_check_output(*popenargs, **kwargs): + if 'stdout' in kwargs: + raise ValueError('stdout argument not allowed, it will be overridden.') + process = subprocess.Popen(stdout=subprocess.PIPE, *popenargs, **kwargs) + output, unused_err = process.communicate() + retcode = process.poll() + if retcode: + cmd = kwargs.get("args") + if cmd is None: + cmd = popenargs[0] + raise subprocess.CalledProcessError(retcode, cmd, output=output) + return output # Append `SPARK_HOME/dev` to the Python path so that we can import the sparktestsupport module @@ -156,11 +173,11 @@ def main(): task_queue = Queue.Queue() for python_exec in python_execs: - python_implementation = subprocess.check_output( + python_implementation = subprocess_check_output( [python_exec, "-c", "import platform; print(platform.python_implementation())"], universal_newlines=True).strip() LOGGER.debug("%s python_implementation is %s", python_exec, python_implementation) - LOGGER.debug("%s version is: %s", python_exec, subprocess.check_output( + LOGGER.debug("%s version is: %s", python_exec, subprocess_check_output( [python_exec, "--version"], stderr=subprocess.STDOUT, universal_newlines=True).strip()) for module in modules_to_test: if python_implementation not in module.blacklisted_python_implementations: From 69c5dee2f01b1ae35bd813d31d46429a32cb475d Mon Sep 17 00:00:00 2001 From: Sun Rui Date: Wed, 1 Jul 2015 09:50:12 -0700 Subject: [PATCH 05/18] [SPARK-7714] [SPARKR] SparkR tests should use more specific expectations than expect_true 1. Update the pattern 'expect_true(a == b)' to 'expect_equal(a, b)'. 2. Update the pattern 'expect_true(inherits(a, b))' to 'expect_is(a, b)'. 3. Update the pattern 'expect_true(identical(a, b))' to 'expect_identical(a, b)'. Author: Sun Rui Closes #7152 from sun-rui/SPARK-7714 and squashes the following commits: 8ad2440 [Sun Rui] Fix test case errors. 8fe9f0c [Sun Rui] Update the pattern 'expect_true(identical(a, b))' to 'expect_identical(a, b)'. f1b8005 [Sun Rui] Update the pattern 'expect_true(inherits(a, b))' to 'expect_is(a, b)'. f631e94 [Sun Rui] Update the pattern 'expect_true(a == b)' to 'expect_equal(a, b)'. --- R/pkg/inst/tests/test_binaryFile.R | 2 +- R/pkg/inst/tests/test_binary_function.R | 4 +- R/pkg/inst/tests/test_includeJAR.R | 4 +- R/pkg/inst/tests/test_parallelize_collect.R | 2 +- R/pkg/inst/tests/test_rdd.R | 4 +- R/pkg/inst/tests/test_sparkSQL.R | 354 ++++++++++---------- R/pkg/inst/tests/test_take.R | 8 +- R/pkg/inst/tests/test_textFile.R | 6 +- R/pkg/inst/tests/test_utils.R | 4 +- 9 files changed, 194 insertions(+), 194 deletions(-) diff --git a/R/pkg/inst/tests/test_binaryFile.R b/R/pkg/inst/tests/test_binaryFile.R index 4db7266abc8e2..ccaea18ecab2a 100644 --- a/R/pkg/inst/tests/test_binaryFile.R +++ b/R/pkg/inst/tests/test_binaryFile.R @@ -82,7 +82,7 @@ test_that("saveAsObjectFile()/objectFile() works with multiple paths", { saveAsObjectFile(rdd2, fileName2) rdd <- objectFile(sc, c(fileName1, fileName2)) - expect_true(count(rdd) == 2) + expect_equal(count(rdd), 2) unlink(fileName1, recursive = TRUE) unlink(fileName2, recursive = TRUE) diff --git a/R/pkg/inst/tests/test_binary_function.R b/R/pkg/inst/tests/test_binary_function.R index a1e354e567be5..3be8c65a6c1a0 100644 --- a/R/pkg/inst/tests/test_binary_function.R +++ b/R/pkg/inst/tests/test_binary_function.R @@ -38,13 +38,13 @@ test_that("union on two RDDs", { union.rdd <- unionRDD(rdd, text.rdd) actual <- collect(union.rdd) expect_equal(actual, c(as.list(nums), mockFile)) - expect_true(getSerializedMode(union.rdd) == "byte") + expect_equal(getSerializedMode(union.rdd), "byte") rdd<- map(text.rdd, function(x) {x}) union.rdd <- unionRDD(rdd, text.rdd) actual <- collect(union.rdd) expect_equal(actual, as.list(c(mockFile, mockFile))) - expect_true(getSerializedMode(union.rdd) == "byte") + expect_equal(getSerializedMode(union.rdd), "byte") unlink(fileName) }) diff --git a/R/pkg/inst/tests/test_includeJAR.R b/R/pkg/inst/tests/test_includeJAR.R index 8bc693be20c3c..844d86f3cc97f 100644 --- a/R/pkg/inst/tests/test_includeJAR.R +++ b/R/pkg/inst/tests/test_includeJAR.R @@ -31,7 +31,7 @@ runScript <- function() { test_that("sparkJars tag in SparkContext", { testOutput <- runScript() helloTest <- testOutput[1] - expect_true(helloTest == "Hello, Dave") + expect_equal(helloTest, "Hello, Dave") basicFunction <- testOutput[2] - expect_true(basicFunction == 4L) + expect_equal(basicFunction, "4") }) diff --git a/R/pkg/inst/tests/test_parallelize_collect.R b/R/pkg/inst/tests/test_parallelize_collect.R index fff028657db37..2552127cc547f 100644 --- a/R/pkg/inst/tests/test_parallelize_collect.R +++ b/R/pkg/inst/tests/test_parallelize_collect.R @@ -57,7 +57,7 @@ test_that("parallelize() on simple vectors and lists returns an RDD", { strListRDD2) for (rdd in rdds) { - expect_true(inherits(rdd, "RDD")) + expect_is(rdd, "RDD") expect_true(.hasSlot(rdd, "jrdd") && inherits(rdd@jrdd, "jobj") && isInstanceOf(rdd@jrdd, "org.apache.spark.api.java.JavaRDD")) diff --git a/R/pkg/inst/tests/test_rdd.R b/R/pkg/inst/tests/test_rdd.R index 4fe653856756e..fc3c01d837de4 100644 --- a/R/pkg/inst/tests/test_rdd.R +++ b/R/pkg/inst/tests/test_rdd.R @@ -33,9 +33,9 @@ test_that("get number of partitions in RDD", { }) test_that("first on RDD", { - expect_true(first(rdd) == 1) + expect_equal(first(rdd), 1) newrdd <- lapply(rdd, function(x) x + 1) - expect_true(first(newrdd) == 2) + expect_equal(first(newrdd), 2) }) test_that("count and length on RDD", { diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 6a08f894313c4..0e4235ea8b4b3 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -61,7 +61,7 @@ test_that("infer types", { expect_equal(infer_type(list(1L, 2L)), list(type = 'array', elementType = "integer", containsNull = TRUE)) testStruct <- infer_type(list(a = 1L, b = "2")) - expect_true(class(testStruct) == "structType") + expect_equal(class(testStruct), "structType") checkStructField(testStruct$fields()[[1]], "a", "IntegerType", TRUE) checkStructField(testStruct$fields()[[2]], "b", "StringType", TRUE) e <- new.env() @@ -73,39 +73,39 @@ test_that("infer types", { test_that("structType and structField", { testField <- structField("a", "string") - expect_true(inherits(testField, "structField")) - expect_true(testField$name() == "a") + expect_is(testField, "structField") + expect_equal(testField$name(), "a") expect_true(testField$nullable()) testSchema <- structType(testField, structField("b", "integer")) - expect_true(inherits(testSchema, "structType")) - expect_true(inherits(testSchema$fields()[[2]], "structField")) - expect_true(testSchema$fields()[[1]]$dataType.toString() == "StringType") + expect_is(testSchema, "structType") + expect_is(testSchema$fields()[[2]], "structField") + expect_equal(testSchema$fields()[[1]]$dataType.toString(), "StringType") }) test_that("create DataFrame from RDD", { rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) }) df <- createDataFrame(sqlContext, rdd, list("a", "b")) - expect_true(inherits(df, "DataFrame")) - expect_true(count(df) == 10) + expect_is(df, "DataFrame") + expect_equal(count(df), 10) expect_equal(columns(df), c("a", "b")) expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) df <- createDataFrame(sqlContext, rdd) - expect_true(inherits(df, "DataFrame")) + expect_is(df, "DataFrame") expect_equal(columns(df), c("_1", "_2")) schema <- structType(structField(x = "a", type = "integer", nullable = TRUE), structField(x = "b", type = "string", nullable = TRUE)) df <- createDataFrame(sqlContext, rdd, schema) - expect_true(inherits(df, "DataFrame")) + expect_is(df, "DataFrame") expect_equal(columns(df), c("a", "b")) expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) rdd <- lapply(parallelize(sc, 1:10), function(x) { list(a = x, b = as.character(x)) }) df <- createDataFrame(sqlContext, rdd) - expect_true(inherits(df, "DataFrame")) - expect_true(count(df) == 10) + expect_is(df, "DataFrame") + expect_equal(count(df), 10) expect_equal(columns(df), c("a", "b")) expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) }) @@ -150,26 +150,26 @@ test_that("convert NAs to null type in DataFrames", { test_that("toDF", { rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) }) df <- toDF(rdd, list("a", "b")) - expect_true(inherits(df, "DataFrame")) - expect_true(count(df) == 10) + expect_is(df, "DataFrame") + expect_equal(count(df), 10) expect_equal(columns(df), c("a", "b")) expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) df <- toDF(rdd) - expect_true(inherits(df, "DataFrame")) + expect_is(df, "DataFrame") expect_equal(columns(df), c("_1", "_2")) schema <- structType(structField(x = "a", type = "integer", nullable = TRUE), structField(x = "b", type = "string", nullable = TRUE)) df <- toDF(rdd, schema) - expect_true(inherits(df, "DataFrame")) + expect_is(df, "DataFrame") expect_equal(columns(df), c("a", "b")) expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) rdd <- lapply(parallelize(sc, 1:10), function(x) { list(a = x, b = as.character(x)) }) df <- toDF(rdd) - expect_true(inherits(df, "DataFrame")) - expect_true(count(df) == 10) + expect_is(df, "DataFrame") + expect_equal(count(df), 10) expect_equal(columns(df), c("a", "b")) expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) }) @@ -219,21 +219,21 @@ test_that("create DataFrame with different data types", { test_that("jsonFile() on a local file returns a DataFrame", { df <- jsonFile(sqlContext, jsonPath) - expect_true(inherits(df, "DataFrame")) - expect_true(count(df) == 3) + expect_is(df, "DataFrame") + expect_equal(count(df), 3) }) test_that("jsonRDD() on a RDD with json string", { rdd <- parallelize(sc, mockLines) - expect_true(count(rdd) == 3) + expect_equal(count(rdd), 3) df <- jsonRDD(sqlContext, rdd) - expect_true(inherits(df, "DataFrame")) - expect_true(count(df) == 3) + expect_is(df, "DataFrame") + expect_equal(count(df), 3) rdd2 <- flatMap(rdd, function(x) c(x, x)) df <- jsonRDD(sqlContext, rdd2) - expect_true(inherits(df, "DataFrame")) - expect_true(count(df) == 6) + expect_is(df, "DataFrame") + expect_equal(count(df), 6) }) test_that("test cache, uncache and clearCache", { @@ -248,9 +248,9 @@ test_that("test cache, uncache and clearCache", { test_that("test tableNames and tables", { df <- jsonFile(sqlContext, jsonPath) registerTempTable(df, "table1") - expect_true(length(tableNames(sqlContext)) == 1) + expect_equal(length(tableNames(sqlContext)), 1) df <- tables(sqlContext) - expect_true(count(df) == 1) + expect_equal(count(df), 1) dropTempTable(sqlContext, "table1") }) @@ -258,8 +258,8 @@ test_that("registerTempTable() results in a queryable table and sql() results in df <- jsonFile(sqlContext, jsonPath) registerTempTable(df, "table1") newdf <- sql(sqlContext, "SELECT * FROM table1 where name = 'Michael'") - expect_true(inherits(newdf, "DataFrame")) - expect_true(count(newdf) == 1) + expect_is(newdf, "DataFrame") + expect_equal(count(newdf), 1) dropTempTable(sqlContext, "table1") }) @@ -279,14 +279,14 @@ test_that("insertInto() on a registered table", { registerTempTable(dfParquet, "table1") insertInto(dfParquet2, "table1") - expect_true(count(sql(sqlContext, "select * from table1")) == 5) - expect_true(first(sql(sqlContext, "select * from table1 order by age"))$name == "Michael") + expect_equal(count(sql(sqlContext, "select * from table1")), 5) + expect_equal(first(sql(sqlContext, "select * from table1 order by age"))$name, "Michael") dropTempTable(sqlContext, "table1") registerTempTable(dfParquet, "table1") insertInto(dfParquet2, "table1", overwrite = TRUE) - expect_true(count(sql(sqlContext, "select * from table1")) == 2) - expect_true(first(sql(sqlContext, "select * from table1 order by age"))$name == "Bob") + expect_equal(count(sql(sqlContext, "select * from table1")), 2) + expect_equal(first(sql(sqlContext, "select * from table1 order by age"))$name, "Bob") dropTempTable(sqlContext, "table1") }) @@ -294,16 +294,16 @@ test_that("table() returns a new DataFrame", { df <- jsonFile(sqlContext, jsonPath) registerTempTable(df, "table1") tabledf <- table(sqlContext, "table1") - expect_true(inherits(tabledf, "DataFrame")) - expect_true(count(tabledf) == 3) + expect_is(tabledf, "DataFrame") + expect_equal(count(tabledf), 3) dropTempTable(sqlContext, "table1") }) test_that("toRDD() returns an RRDD", { df <- jsonFile(sqlContext, jsonPath) testRDD <- toRDD(df) - expect_true(inherits(testRDD, "RDD")) - expect_true(count(testRDD) == 3) + expect_is(testRDD, "RDD") + expect_equal(count(testRDD), 3) }) test_that("union on two RDDs created from DataFrames returns an RRDD", { @@ -311,9 +311,9 @@ test_that("union on two RDDs created from DataFrames returns an RRDD", { RDD1 <- toRDD(df) RDD2 <- toRDD(df) unioned <- unionRDD(RDD1, RDD2) - expect_true(inherits(unioned, "RDD")) - expect_true(SparkR:::getSerializedMode(unioned) == "byte") - expect_true(collect(unioned)[[2]]$name == "Andy") + expect_is(unioned, "RDD") + expect_equal(SparkR:::getSerializedMode(unioned), "byte") + expect_equal(collect(unioned)[[2]]$name, "Andy") }) test_that("union on mixed serialization types correctly returns a byte RRDD", { @@ -333,16 +333,16 @@ test_that("union on mixed serialization types correctly returns a byte RRDD", { dfRDD <- toRDD(df) unionByte <- unionRDD(rdd, dfRDD) - expect_true(inherits(unionByte, "RDD")) - expect_true(SparkR:::getSerializedMode(unionByte) == "byte") - expect_true(collect(unionByte)[[1]] == 1) - expect_true(collect(unionByte)[[12]]$name == "Andy") + expect_is(unionByte, "RDD") + expect_equal(SparkR:::getSerializedMode(unionByte), "byte") + expect_equal(collect(unionByte)[[1]], 1) + expect_equal(collect(unionByte)[[12]]$name, "Andy") unionString <- unionRDD(textRDD, dfRDD) - expect_true(inherits(unionString, "RDD")) - expect_true(SparkR:::getSerializedMode(unionString) == "byte") - expect_true(collect(unionString)[[1]] == "Michael") - expect_true(collect(unionString)[[5]]$name == "Andy") + expect_is(unionString, "RDD") + expect_equal(SparkR:::getSerializedMode(unionString), "byte") + expect_equal(collect(unionString)[[1]], "Michael") + expect_equal(collect(unionString)[[5]]$name, "Andy") }) test_that("objectFile() works with row serialization", { @@ -352,7 +352,7 @@ test_that("objectFile() works with row serialization", { saveAsObjectFile(coalesce(dfRDD, 1L), objectPath) objectIn <- objectFile(sc, objectPath) - expect_true(inherits(objectIn, "RDD")) + expect_is(objectIn, "RDD") expect_equal(SparkR:::getSerializedMode(objectIn), "byte") expect_equal(collect(objectIn)[[2]]$age, 30) }) @@ -363,32 +363,32 @@ test_that("lapply() on a DataFrame returns an RDD with the correct columns", { row$newCol <- row$age + 5 row }) - expect_true(inherits(testRDD, "RDD")) + expect_is(testRDD, "RDD") collected <- collect(testRDD) - expect_true(collected[[1]]$name == "Michael") - expect_true(collected[[2]]$newCol == "35") + expect_equal(collected[[1]]$name, "Michael") + expect_equal(collected[[2]]$newCol, 35) }) test_that("collect() returns a data.frame", { df <- jsonFile(sqlContext, jsonPath) rdf <- collect(df) expect_true(is.data.frame(rdf)) - expect_true(names(rdf)[1] == "age") - expect_true(nrow(rdf) == 3) - expect_true(ncol(rdf) == 2) + expect_equal(names(rdf)[1], "age") + expect_equal(nrow(rdf), 3) + expect_equal(ncol(rdf), 2) }) test_that("limit() returns DataFrame with the correct number of rows", { df <- jsonFile(sqlContext, jsonPath) dfLimited <- limit(df, 2) - expect_true(inherits(dfLimited, "DataFrame")) - expect_true(count(dfLimited) == 2) + expect_is(dfLimited, "DataFrame") + expect_equal(count(dfLimited), 2) }) test_that("collect() and take() on a DataFrame return the same number of rows and columns", { df <- jsonFile(sqlContext, jsonPath) - expect_true(nrow(collect(df)) == nrow(take(df, 10))) - expect_true(ncol(collect(df)) == ncol(take(df, 10))) + expect_equal(nrow(collect(df)), nrow(take(df, 10))) + expect_equal(ncol(collect(df)), ncol(take(df, 10))) }) test_that("multiple pipeline transformations starting with a DataFrame result in an RDD with the correct values", { @@ -401,9 +401,9 @@ test_that("multiple pipeline transformations starting with a DataFrame result in row$testCol <- if (row$age == 35 && !is.na(row$age)) TRUE else FALSE row }) - expect_true(inherits(second, "RDD")) - expect_true(count(second) == 3) - expect_true(collect(second)[[2]]$age == 35) + expect_is(second, "RDD") + expect_equal(count(second), 3) + expect_equal(collect(second)[[2]]$age, 35) expect_true(collect(second)[[2]]$testCol) expect_false(collect(second)[[3]]$testCol) }) @@ -430,36 +430,36 @@ test_that("cache(), persist(), and unpersist() on a DataFrame", { test_that("schema(), dtypes(), columns(), names() return the correct values/format", { df <- jsonFile(sqlContext, jsonPath) testSchema <- schema(df) - expect_true(length(testSchema$fields()) == 2) - expect_true(testSchema$fields()[[1]]$dataType.toString() == "LongType") - expect_true(testSchema$fields()[[2]]$dataType.simpleString() == "string") - expect_true(testSchema$fields()[[1]]$name() == "age") + expect_equal(length(testSchema$fields()), 2) + expect_equal(testSchema$fields()[[1]]$dataType.toString(), "LongType") + expect_equal(testSchema$fields()[[2]]$dataType.simpleString(), "string") + expect_equal(testSchema$fields()[[1]]$name(), "age") testTypes <- dtypes(df) - expect_true(length(testTypes[[1]]) == 2) - expect_true(testTypes[[1]][1] == "age") + expect_equal(length(testTypes[[1]]), 2) + expect_equal(testTypes[[1]][1], "age") testCols <- columns(df) - expect_true(length(testCols) == 2) - expect_true(testCols[2] == "name") + expect_equal(length(testCols), 2) + expect_equal(testCols[2], "name") testNames <- names(df) - expect_true(length(testNames) == 2) - expect_true(testNames[2] == "name") + expect_equal(length(testNames), 2) + expect_equal(testNames[2], "name") }) test_that("head() and first() return the correct data", { df <- jsonFile(sqlContext, jsonPath) testHead <- head(df) - expect_true(nrow(testHead) == 3) - expect_true(ncol(testHead) == 2) + expect_equal(nrow(testHead), 3) + expect_equal(ncol(testHead), 2) testHead2 <- head(df, 2) - expect_true(nrow(testHead2) == 2) - expect_true(ncol(testHead2) == 2) + expect_equal(nrow(testHead2), 2) + expect_equal(ncol(testHead2), 2) testFirst <- first(df) - expect_true(nrow(testFirst) == 1) + expect_equal(nrow(testFirst), 1) }) test_that("distinct() on DataFrames", { @@ -472,15 +472,15 @@ test_that("distinct() on DataFrames", { df <- jsonFile(sqlContext, jsonPathWithDup) uniques <- distinct(df) - expect_true(inherits(uniques, "DataFrame")) - expect_true(count(uniques) == 3) + expect_is(uniques, "DataFrame") + expect_equal(count(uniques), 3) }) test_that("sample on a DataFrame", { df <- jsonFile(sqlContext, jsonPath) sampled <- sample(df, FALSE, 1.0) expect_equal(nrow(collect(sampled)), count(df)) - expect_true(inherits(sampled, "DataFrame")) + expect_is(sampled, "DataFrame") sampled2 <- sample(df, FALSE, 0.1) expect_true(count(sampled2) < 3) @@ -491,15 +491,15 @@ test_that("sample on a DataFrame", { test_that("select operators", { df <- select(jsonFile(sqlContext, jsonPath), "name", "age") - expect_true(inherits(df$name, "Column")) - expect_true(inherits(df[[2]], "Column")) - expect_true(inherits(df[["age"]], "Column")) + expect_is(df$name, "Column") + expect_is(df[[2]], "Column") + expect_is(df[["age"]], "Column") - expect_true(inherits(df[,1], "DataFrame")) + expect_is(df[,1], "DataFrame") expect_equal(columns(df[,1]), c("name")) expect_equal(columns(df[,"age"]), c("age")) df2 <- df[,c("age", "name")] - expect_true(inherits(df2, "DataFrame")) + expect_is(df2, "DataFrame") expect_equal(columns(df2), c("age", "name")) df$age2 <- df$age @@ -518,50 +518,50 @@ test_that("select operators", { test_that("select with column", { df <- jsonFile(sqlContext, jsonPath) df1 <- select(df, "name") - expect_true(columns(df1) == c("name")) - expect_true(count(df1) == 3) + expect_equal(columns(df1), c("name")) + expect_equal(count(df1), 3) df2 <- select(df, df$age) - expect_true(columns(df2) == c("age")) - expect_true(count(df2) == 3) + expect_equal(columns(df2), c("age")) + expect_equal(count(df2), 3) }) test_that("selectExpr() on a DataFrame", { df <- jsonFile(sqlContext, jsonPath) selected <- selectExpr(df, "age * 2") - expect_true(names(selected) == "(age * 2)") + expect_equal(names(selected), "(age * 2)") expect_equal(collect(selected), collect(select(df, df$age * 2L))) selected2 <- selectExpr(df, "name as newName", "abs(age) as age") expect_equal(names(selected2), c("newName", "age")) - expect_true(count(selected2) == 3) + expect_equal(count(selected2), 3) }) test_that("column calculation", { df <- jsonFile(sqlContext, jsonPath) d <- collect(select(df, alias(df$age + 1, "age2"))) - expect_true(names(d) == c("age2")) + expect_equal(names(d), c("age2")) df2 <- select(df, lower(df$name), abs(df$age)) - expect_true(inherits(df2, "DataFrame")) - expect_true(count(df2) == 3) + expect_is(df2, "DataFrame") + expect_equal(count(df2), 3) }) test_that("read.df() from json file", { df <- read.df(sqlContext, jsonPath, "json") - expect_true(inherits(df, "DataFrame")) - expect_true(count(df) == 3) + expect_is(df, "DataFrame") + expect_equal(count(df), 3) # Check if we can apply a user defined schema schema <- structType(structField("name", type = "string"), structField("age", type = "double")) df1 <- read.df(sqlContext, jsonPath, "json", schema) - expect_true(inherits(df1, "DataFrame")) + expect_is(df1, "DataFrame") expect_equal(dtypes(df1), list(c("name", "string"), c("age", "double"))) # Run the same with loadDF df2 <- loadDF(sqlContext, jsonPath, "json", schema) - expect_true(inherits(df2, "DataFrame")) + expect_is(df2, "DataFrame") expect_equal(dtypes(df2), list(c("name", "string"), c("age", "double"))) }) @@ -569,8 +569,8 @@ test_that("write.df() as parquet file", { df <- read.df(sqlContext, jsonPath, "json") write.df(df, parquetPath, "parquet", mode="overwrite") df2 <- read.df(sqlContext, parquetPath, "parquet") - expect_true(inherits(df2, "DataFrame")) - expect_true(count(df2) == 3) + expect_is(df2, "DataFrame") + expect_equal(count(df2), 3) }) test_that("test HiveContext", { @@ -580,17 +580,17 @@ test_that("test HiveContext", { skip("Hive is not build with SparkSQL, skipped") }) df <- createExternalTable(hiveCtx, "json", jsonPath, "json") - expect_true(inherits(df, "DataFrame")) - expect_true(count(df) == 3) + expect_is(df, "DataFrame") + expect_equal(count(df), 3) df2 <- sql(hiveCtx, "select * from json") - expect_true(inherits(df2, "DataFrame")) - expect_true(count(df2) == 3) + expect_is(df2, "DataFrame") + expect_equal(count(df2), 3) jsonPath2 <- tempfile(pattern="sparkr-test", fileext=".tmp") saveAsTable(df, "json", "json", "append", path = jsonPath2) df3 <- sql(hiveCtx, "select * from json") - expect_true(inherits(df3, "DataFrame")) - expect_true(count(df3) == 6) + expect_is(df3, "DataFrame") + expect_equal(count(df3), 6) }) test_that("column operators", { @@ -643,65 +643,65 @@ test_that("string operators", { test_that("group by", { df <- jsonFile(sqlContext, jsonPath) df1 <- agg(df, name = "max", age = "sum") - expect_true(1 == count(df1)) + expect_equal(1, count(df1)) df1 <- agg(df, age2 = max(df$age)) - expect_true(1 == count(df1)) + expect_equal(1, count(df1)) expect_equal(columns(df1), c("age2")) gd <- groupBy(df, "name") - expect_true(inherits(gd, "GroupedData")) + expect_is(gd, "GroupedData") df2 <- count(gd) - expect_true(inherits(df2, "DataFrame")) - expect_true(3 == count(df2)) + expect_is(df2, "DataFrame") + expect_equal(3, count(df2)) # Also test group_by, summarize, mean gd1 <- group_by(df, "name") - expect_true(inherits(gd1, "GroupedData")) + expect_is(gd1, "GroupedData") df_summarized <- summarize(gd, mean_age = mean(df$age)) - expect_true(inherits(df_summarized, "DataFrame")) - expect_true(3 == count(df_summarized)) + expect_is(df_summarized, "DataFrame") + expect_equal(3, count(df_summarized)) df3 <- agg(gd, age = "sum") - expect_true(inherits(df3, "DataFrame")) - expect_true(3 == count(df3)) + expect_is(df3, "DataFrame") + expect_equal(3, count(df3)) df3 <- agg(gd, age = sum(df$age)) - expect_true(inherits(df3, "DataFrame")) - expect_true(3 == count(df3)) + expect_is(df3, "DataFrame") + expect_equal(3, count(df3)) expect_equal(columns(df3), c("name", "age")) df4 <- sum(gd, "age") - expect_true(inherits(df4, "DataFrame")) - expect_true(3 == count(df4)) - expect_true(3 == count(mean(gd, "age"))) - expect_true(3 == count(max(gd, "age"))) + expect_is(df4, "DataFrame") + expect_equal(3, count(df4)) + expect_equal(3, count(mean(gd, "age"))) + expect_equal(3, count(max(gd, "age"))) }) test_that("arrange() and orderBy() on a DataFrame", { df <- jsonFile(sqlContext, jsonPath) sorted <- arrange(df, df$age) - expect_true(collect(sorted)[1,2] == "Michael") + expect_equal(collect(sorted)[1,2], "Michael") sorted2 <- arrange(df, "name") - expect_true(collect(sorted2)[2,"age"] == 19) + expect_equal(collect(sorted2)[2,"age"], 19) sorted3 <- orderBy(df, asc(df$age)) expect_true(is.na(first(sorted3)$age)) - expect_true(collect(sorted3)[2, "age"] == 19) + expect_equal(collect(sorted3)[2, "age"], 19) sorted4 <- orderBy(df, desc(df$name)) - expect_true(first(sorted4)$name == "Michael") - expect_true(collect(sorted4)[3,"name"] == "Andy") + expect_equal(first(sorted4)$name, "Michael") + expect_equal(collect(sorted4)[3,"name"], "Andy") }) test_that("filter() on a DataFrame", { df <- jsonFile(sqlContext, jsonPath) filtered <- filter(df, "age > 20") - expect_true(count(filtered) == 1) - expect_true(collect(filtered)$name == "Andy") + expect_equal(count(filtered), 1) + expect_equal(collect(filtered)$name, "Andy") filtered2 <- where(df, df$name != "Michael") - expect_true(count(filtered2) == 2) - expect_true(collect(filtered2)$age[2] == 19) + expect_equal(count(filtered2), 2) + expect_equal(collect(filtered2)$age[2], 19) # test suites for %in% filtered3 <- filter(df, "age in (19)") @@ -727,29 +727,29 @@ test_that("join() on a DataFrame", { joined <- join(df, df2) expect_equal(names(joined), c("age", "name", "name", "test")) - expect_true(count(joined) == 12) + expect_equal(count(joined), 12) joined2 <- join(df, df2, df$name == df2$name) expect_equal(names(joined2), c("age", "name", "name", "test")) - expect_true(count(joined2) == 3) + expect_equal(count(joined2), 3) joined3 <- join(df, df2, df$name == df2$name, "right_outer") expect_equal(names(joined3), c("age", "name", "name", "test")) - expect_true(count(joined3) == 4) + expect_equal(count(joined3), 4) expect_true(is.na(collect(orderBy(joined3, joined3$age))$age[2])) joined4 <- select(join(df, df2, df$name == df2$name, "outer"), alias(df$age + 5, "newAge"), df$name, df2$test) expect_equal(names(joined4), c("newAge", "name", "test")) - expect_true(count(joined4) == 4) + expect_equal(count(joined4), 4) expect_equal(collect(orderBy(joined4, joined4$name))$newAge[3], 24) }) test_that("toJSON() returns an RDD of the correct values", { df <- jsonFile(sqlContext, jsonPath) testRDD <- toJSON(df) - expect_true(inherits(testRDD, "RDD")) - expect_true(SparkR:::getSerializedMode(testRDD) == "string") + expect_is(testRDD, "RDD") + expect_equal(SparkR:::getSerializedMode(testRDD), "string") expect_equal(collect(testRDD)[[1]], mockLines[1]) }) @@ -775,50 +775,50 @@ test_that("unionAll(), except(), and intersect() on a DataFrame", { df2 <- read.df(sqlContext, jsonPath2, "json") unioned <- arrange(unionAll(df, df2), df$age) - expect_true(inherits(unioned, "DataFrame")) - expect_true(count(unioned) == 6) - expect_true(first(unioned)$name == "Michael") + expect_is(unioned, "DataFrame") + expect_equal(count(unioned), 6) + expect_equal(first(unioned)$name, "Michael") excepted <- arrange(except(df, df2), desc(df$age)) - expect_true(inherits(unioned, "DataFrame")) - expect_true(count(excepted) == 2) - expect_true(first(excepted)$name == "Justin") + expect_is(unioned, "DataFrame") + expect_equal(count(excepted), 2) + expect_equal(first(excepted)$name, "Justin") intersected <- arrange(intersect(df, df2), df$age) - expect_true(inherits(unioned, "DataFrame")) - expect_true(count(intersected) == 1) - expect_true(first(intersected)$name == "Andy") + expect_is(unioned, "DataFrame") + expect_equal(count(intersected), 1) + expect_equal(first(intersected)$name, "Andy") }) test_that("withColumn() and withColumnRenamed()", { df <- jsonFile(sqlContext, jsonPath) newDF <- withColumn(df, "newAge", df$age + 2) - expect_true(length(columns(newDF)) == 3) - expect_true(columns(newDF)[3] == "newAge") - expect_true(first(filter(newDF, df$name != "Michael"))$newAge == 32) + expect_equal(length(columns(newDF)), 3) + expect_equal(columns(newDF)[3], "newAge") + expect_equal(first(filter(newDF, df$name != "Michael"))$newAge, 32) newDF2 <- withColumnRenamed(df, "age", "newerAge") - expect_true(length(columns(newDF2)) == 2) - expect_true(columns(newDF2)[1] == "newerAge") + expect_equal(length(columns(newDF2)), 2) + expect_equal(columns(newDF2)[1], "newerAge") }) test_that("mutate() and rename()", { df <- jsonFile(sqlContext, jsonPath) newDF <- mutate(df, newAge = df$age + 2) - expect_true(length(columns(newDF)) == 3) - expect_true(columns(newDF)[3] == "newAge") - expect_true(first(filter(newDF, df$name != "Michael"))$newAge == 32) + expect_equal(length(columns(newDF)), 3) + expect_equal(columns(newDF)[3], "newAge") + expect_equal(first(filter(newDF, df$name != "Michael"))$newAge, 32) newDF2 <- rename(df, newerAge = df$age) - expect_true(length(columns(newDF2)) == 2) - expect_true(columns(newDF2)[1] == "newerAge") + expect_equal(length(columns(newDF2)), 2) + expect_equal(columns(newDF2)[1], "newerAge") }) test_that("write.df() on DataFrame and works with parquetFile", { df <- jsonFile(sqlContext, jsonPath) write.df(df, parquetPath, "parquet", mode="overwrite") parquetDF <- parquetFile(sqlContext, parquetPath) - expect_true(inherits(parquetDF, "DataFrame")) + expect_is(parquetDF, "DataFrame") expect_equal(count(df), count(parquetDF)) }) @@ -828,8 +828,8 @@ test_that("parquetFile works with multiple input paths", { parquetPath2 <- tempfile(pattern = "parquetPath2", fileext = ".parquet") write.df(df, parquetPath2, "parquet", mode="overwrite") parquetDF <- parquetFile(sqlContext, parquetPath, parquetPath2) - expect_true(inherits(parquetDF, "DataFrame")) - expect_true(count(parquetDF) == count(df)*2) + expect_is(parquetDF, "DataFrame") + expect_equal(count(parquetDF), count(df)*2) }) test_that("describe() on a DataFrame", { @@ -851,58 +851,58 @@ test_that("dropna() on a DataFrame", { expected <- rows[!is.na(rows$name),] actual <- collect(dropna(df, cols = "name")) - expect_true(identical(expected, actual)) + expect_identical(expected, actual) expected <- rows[!is.na(rows$age),] actual <- collect(dropna(df, cols = "age")) row.names(expected) <- row.names(actual) # identical on two dataframes does not work here. Don't know why. # use identical on all columns as a workaround. - expect_true(identical(expected$age, actual$age)) - expect_true(identical(expected$height, actual$height)) - expect_true(identical(expected$name, actual$name)) + expect_identical(expected$age, actual$age) + expect_identical(expected$height, actual$height) + expect_identical(expected$name, actual$name) expected <- rows[!is.na(rows$age) & !is.na(rows$height),] actual <- collect(dropna(df, cols = c("age", "height"))) - expect_true(identical(expected, actual)) + expect_identical(expected, actual) expected <- rows[!is.na(rows$age) & !is.na(rows$height) & !is.na(rows$name),] actual <- collect(dropna(df)) - expect_true(identical(expected, actual)) + expect_identical(expected, actual) # drop with how expected <- rows[!is.na(rows$age) & !is.na(rows$height) & !is.na(rows$name),] actual <- collect(dropna(df)) - expect_true(identical(expected, actual)) + expect_identical(expected, actual) expected <- rows[!is.na(rows$age) | !is.na(rows$height) | !is.na(rows$name),] actual <- collect(dropna(df, "all")) - expect_true(identical(expected, actual)) + expect_identical(expected, actual) expected <- rows[!is.na(rows$age) & !is.na(rows$height) & !is.na(rows$name),] actual <- collect(dropna(df, "any")) - expect_true(identical(expected, actual)) + expect_identical(expected, actual) expected <- rows[!is.na(rows$age) & !is.na(rows$height),] actual <- collect(dropna(df, "any", cols = c("age", "height"))) - expect_true(identical(expected, actual)) + expect_identical(expected, actual) expected <- rows[!is.na(rows$age) | !is.na(rows$height),] actual <- collect(dropna(df, "all", cols = c("age", "height"))) - expect_true(identical(expected, actual)) + expect_identical(expected, actual) # drop with threshold expected <- rows[as.integer(!is.na(rows$age)) + as.integer(!is.na(rows$height)) >= 2,] actual <- collect(dropna(df, minNonNulls = 2, cols = c("age", "height"))) - expect_true(identical(expected, actual)) + expect_identical(expected, actual) expected <- rows[as.integer(!is.na(rows$age)) + as.integer(!is.na(rows$height)) + as.integer(!is.na(rows$name)) >= 3,] actual <- collect(dropna(df, minNonNulls = 3, cols = c("name", "age", "height"))) - expect_true(identical(expected, actual)) + expect_identical(expected, actual) }) test_that("fillna() on a DataFrame", { @@ -915,22 +915,22 @@ test_that("fillna() on a DataFrame", { expected$age[is.na(expected$age)] <- 50 expected$height[is.na(expected$height)] <- 50.6 actual <- collect(fillna(df, 50.6)) - expect_true(identical(expected, actual)) + expect_identical(expected, actual) expected <- rows expected$name[is.na(expected$name)] <- "unknown" actual <- collect(fillna(df, "unknown")) - expect_true(identical(expected, actual)) + expect_identical(expected, actual) expected <- rows expected$age[is.na(expected$age)] <- 50 actual <- collect(fillna(df, 50.6, "age")) - expect_true(identical(expected, actual)) + expect_identical(expected, actual) expected <- rows expected$name[is.na(expected$name)] <- "unknown" actual <- collect(fillna(df, "unknown", c("age", "name"))) - expect_true(identical(expected, actual)) + expect_identical(expected, actual) # fill with named list @@ -939,7 +939,7 @@ test_that("fillna() on a DataFrame", { expected$height[is.na(expected$height)] <- 50.6 expected$name[is.na(expected$name)] <- "unknown" actual <- collect(fillna(df, list("age" = 50, "height" = 50.6, "name" = "unknown"))) - expect_true(identical(expected, actual)) + expect_identical(expected, actual) }) unlink(parquetPath) diff --git a/R/pkg/inst/tests/test_take.R b/R/pkg/inst/tests/test_take.R index c5eb417b40159..c2c724cdc762f 100644 --- a/R/pkg/inst/tests/test_take.R +++ b/R/pkg/inst/tests/test_take.R @@ -59,8 +59,8 @@ test_that("take() gives back the original elements in correct count and order", expect_equal(take(strListRDD, 3), as.list(head(strList, n = 3))) expect_equal(take(strListRDD2, 1), as.list(head(strList, n = 1))) - expect_true(length(take(strListRDD, 0)) == 0) - expect_true(length(take(strVectorRDD, 0)) == 0) - expect_true(length(take(numListRDD, 0)) == 0) - expect_true(length(take(numVectorRDD, 0)) == 0) + expect_equal(length(take(strListRDD, 0)), 0) + expect_equal(length(take(strVectorRDD, 0)), 0) + expect_equal(length(take(numListRDD, 0)), 0) + expect_equal(length(take(numVectorRDD, 0)), 0) }) diff --git a/R/pkg/inst/tests/test_textFile.R b/R/pkg/inst/tests/test_textFile.R index 092ad9dc10c2e..58318dfef71ab 100644 --- a/R/pkg/inst/tests/test_textFile.R +++ b/R/pkg/inst/tests/test_textFile.R @@ -27,9 +27,9 @@ test_that("textFile() on a local file returns an RDD", { writeLines(mockFile, fileName) rdd <- textFile(sc, fileName) - expect_true(inherits(rdd, "RDD")) + expect_is(rdd, "RDD") expect_true(count(rdd) > 0) - expect_true(count(rdd) == 2) + expect_equal(count(rdd), 2) unlink(fileName) }) @@ -133,7 +133,7 @@ test_that("textFile() on multiple paths", { writeLines("Spark is awesome.", fileName2) rdd <- textFile(sc, c(fileName1, fileName2)) - expect_true(count(rdd) == 2) + expect_equal(count(rdd), 2) unlink(fileName1) unlink(fileName2) diff --git a/R/pkg/inst/tests/test_utils.R b/R/pkg/inst/tests/test_utils.R index 15030e6f1d77e..aa0d2a66b9082 100644 --- a/R/pkg/inst/tests/test_utils.R +++ b/R/pkg/inst/tests/test_utils.R @@ -45,10 +45,10 @@ test_that("serializeToBytes on RDD", { writeLines(mockFile, fileName) text.rdd <- textFile(sc, fileName) - expect_true(getSerializedMode(text.rdd) == "string") + expect_equal(getSerializedMode(text.rdd), "string") ser.rdd <- serializeToBytes(text.rdd) expect_equal(collect(ser.rdd), as.list(mockFile)) - expect_true(getSerializedMode(ser.rdd) == "byte") + expect_equal(getSerializedMode(ser.rdd), "byte") unlink(fileName) }) From 4137f769b84300648ad933b0b3054d69a7316745 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 1 Jul 2015 10:30:54 -0700 Subject: [PATCH 06/18] [SPARK-8752][SQL] Add ExpectsInputTypes trait for defining expected input types. This patch doesn't actually introduce any code that uses the new ExpectsInputTypes. It just adds the trait so others can use it. Also renamed the old expectsInputTypes function to just inputTypes. We should add implicit type casting also in the future. Author: Reynold Xin Closes #7151 from rxin/expects-input-types and squashes the following commits: 16cf07b [Reynold Xin] [SPARK-8752][SQL] Add ExpectsInputTypes trait for defining expected input types. --- .../sql/catalyst/analysis/CheckAnalysis.scala | 1 - .../catalyst/analysis/HiveTypeCoercion.scala | 8 ++--- .../sql/catalyst/expressions/Expression.scala | 29 ++++++++++++++++--- .../spark/sql/catalyst/expressions/math.scala | 6 ++-- .../spark/sql/catalyst/expressions/misc.scala | 8 ++--- .../sql/catalyst/expressions/predicates.scala | 6 ++-- .../expressions/stringOperations.scala | 10 +++---- 7 files changed, 44 insertions(+), 24 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index a069b4710f38c..583338da57117 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -26,7 +26,6 @@ import org.apache.spark.sql.types._ * Throws user facing errors when passed invalid queries that fail to analyze. */ trait CheckAnalysis { - self: Analyzer => /** * Override to provide additional checks for correct analysis. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index a9d396d1faeeb..2ab5cb666fbcd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -45,7 +45,7 @@ object HiveTypeCoercion { IfCoercion :: Division :: PropagateTypes :: - AddCastForAutoCastInputTypes :: + ImplicitTypeCasts :: Nil // See https://cwiki.apache.org/confluence/display/Hive/LanguageManual+Types. @@ -705,13 +705,13 @@ object HiveTypeCoercion { * Casts types according to the expected input types for Expressions that have the trait * [[AutoCastInputTypes]]. */ - object AddCastForAutoCastInputTypes extends Rule[LogicalPlan] { + object ImplicitTypeCasts extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e - case e: AutoCastInputTypes if e.children.map(_.dataType) != e.expectedChildTypes => - val newC = (e.children, e.children.map(_.dataType), e.expectedChildTypes).zipped.map { + case e: AutoCastInputTypes if e.children.map(_.dataType) != e.inputTypes => + val newC = (e.children, e.children.map(_.dataType), e.inputTypes).zipped.map { case (child, actual, expected) => if (actual == expected) child else Cast(child, expected) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index b5063f32fa529..e18a3118945e8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -265,17 +265,38 @@ abstract class UnaryExpression extends Expression with trees.UnaryNode[Expressio } } +/** + * An trait that gets mixin to define the expected input types of an expression. + */ +trait ExpectsInputTypes { self: Expression => + + /** + * Expected input types from child expressions. The i-th position in the returned seq indicates + * the type requirement for the i-th child. + * + * The possible values at each position are: + * 1. a specific data type, e.g. LongType, StringType. + * 2. a non-leaf data type, e.g. NumericType, IntegralType, FractionalType. + * 3. a list of specific data types, e.g. Seq(StringType, BinaryType). + */ + def inputTypes: Seq[Any] + + override def checkInputDataTypes(): TypeCheckResult = { + // We will do the type checking in `HiveTypeCoercion`, so always returning success here. + TypeCheckResult.TypeCheckSuccess + } +} + /** * Expressions that require a specific `DataType` as input should implement this trait * so that the proper type conversions can be performed in the analyzer. */ -trait AutoCastInputTypes { - self: Expression => +trait AutoCastInputTypes { self: Expression => - def expectedChildTypes: Seq[DataType] + def inputTypes: Seq[DataType] override def checkInputDataTypes(): TypeCheckResult = { - // We will always do type casting for `ExpectsInputTypes` in `HiveTypeCoercion`, + // We will always do type casting for `AutoCastInputTypes` in `HiveTypeCoercion`, // so type mismatch error won't be reported here, but for underling `Cast`s. TypeCheckResult.TypeCheckSuccess } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index da63f2fa970cf..b51318dd5044c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -59,7 +59,7 @@ abstract class UnaryMathExpression(f: Double => Double, name: String) extends UnaryExpression with Serializable with AutoCastInputTypes { self: Product => - override def expectedChildTypes: Seq[DataType] = Seq(DoubleType) + override def inputTypes: Seq[DataType] = Seq(DoubleType) override def dataType: DataType = DoubleType override def nullable: Boolean = true override def toString: String = s"$name($child)" @@ -98,7 +98,7 @@ abstract class UnaryMathExpression(f: Double => Double, name: String) abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String) extends BinaryExpression with Serializable with AutoCastInputTypes { self: Product => - override def expectedChildTypes: Seq[DataType] = Seq(DoubleType, DoubleType) + override def inputTypes: Seq[DataType] = Seq(DoubleType, DoubleType) override def toString: String = s"$name($left, $right)" @@ -210,7 +210,7 @@ case class ToRadians(child: Expression) extends UnaryMathExpression(math.toRadia case class Bin(child: Expression) extends UnaryExpression with Serializable with AutoCastInputTypes { - override def expectedChildTypes: Seq[DataType] = Seq(LongType) + override def inputTypes: Seq[DataType] = Seq(LongType) override def dataType: DataType = StringType override def eval(input: InternalRow): Any = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index a7bcbe46c339a..407023e472081 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -36,7 +36,7 @@ case class Md5(child: Expression) override def dataType: DataType = StringType - override def expectedChildTypes: Seq[DataType] = Seq(BinaryType) + override def inputTypes: Seq[DataType] = Seq(BinaryType) override def eval(input: InternalRow): Any = { val value = child.eval(input) @@ -68,7 +68,7 @@ case class Sha2(left: Expression, right: Expression) override def toString: String = s"SHA2($left, $right)" - override def expectedChildTypes: Seq[DataType] = Seq(BinaryType, IntegerType) + override def inputTypes: Seq[DataType] = Seq(BinaryType, IntegerType) override def eval(input: InternalRow): Any = { val evalE1 = left.eval(input) @@ -151,7 +151,7 @@ case class Sha1(child: Expression) extends UnaryExpression with AutoCastInputTyp override def dataType: DataType = StringType - override def expectedChildTypes: Seq[DataType] = Seq(BinaryType) + override def inputTypes: Seq[DataType] = Seq(BinaryType) override def eval(input: InternalRow): Any = { val value = child.eval(input) @@ -179,7 +179,7 @@ case class Crc32(child: Expression) override def dataType: DataType = LongType - override def expectedChildTypes: Seq[DataType] = Seq(BinaryType) + override def inputTypes: Seq[DataType] = Seq(BinaryType) override def eval(input: InternalRow): Any = { val value = child.eval(input) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 98cd5aa8148c4..a777f77add2db 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -72,7 +72,7 @@ trait PredicateHelper { case class Not(child: Expression) extends UnaryExpression with Predicate with AutoCastInputTypes { override def toString: String = s"NOT $child" - override def expectedChildTypes: Seq[DataType] = Seq(BooleanType) + override def inputTypes: Seq[DataType] = Seq(BooleanType) override def eval(input: InternalRow): Any = { child.eval(input) match { @@ -122,7 +122,7 @@ case class InSet(value: Expression, hset: Set[Any]) case class And(left: Expression, right: Expression) extends BinaryExpression with Predicate with AutoCastInputTypes { - override def expectedChildTypes: Seq[DataType] = Seq(BooleanType, BooleanType) + override def inputTypes: Seq[DataType] = Seq(BooleanType, BooleanType) override def symbol: String = "&&" @@ -171,7 +171,7 @@ case class And(left: Expression, right: Expression) case class Or(left: Expression, right: Expression) extends BinaryExpression with Predicate with AutoCastInputTypes { - override def expectedChildTypes: Seq[DataType] = Seq(BooleanType, BooleanType) + override def inputTypes: Seq[DataType] = Seq(BooleanType, BooleanType) override def symbol: String = "||" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index ce184e4f32f18..4cbfc4e084948 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -32,7 +32,7 @@ trait StringRegexExpression extends AutoCastInputTypes { override def nullable: Boolean = left.nullable || right.nullable override def dataType: DataType = BooleanType - override def expectedChildTypes: Seq[DataType] = Seq(StringType, StringType) + override def inputTypes: Seq[DataType] = Seq(StringType, StringType) // try cache the pattern for Literal private lazy val cache: Pattern = right match { @@ -117,7 +117,7 @@ trait CaseConversionExpression extends AutoCastInputTypes { def convert(v: UTF8String): UTF8String override def dataType: DataType = StringType - override def expectedChildTypes: Seq[DataType] = Seq(StringType) + override def inputTypes: Seq[DataType] = Seq(StringType) override def eval(input: InternalRow): Any = { val evaluated = child.eval(input) @@ -165,7 +165,7 @@ trait StringComparison extends AutoCastInputTypes { override def nullable: Boolean = left.nullable || right.nullable - override def expectedChildTypes: Seq[DataType] = Seq(StringType, StringType) + override def inputTypes: Seq[DataType] = Seq(StringType, StringType) override def eval(input: InternalRow): Any = { val leftEval = left.eval(input) @@ -238,7 +238,7 @@ case class Substring(str: Expression, pos: Expression, len: Expression) if (str.dataType == BinaryType) str.dataType else StringType } - override def expectedChildTypes: Seq[DataType] = Seq(StringType, IntegerType, IntegerType) + override def inputTypes: Seq[DataType] = Seq(StringType, IntegerType, IntegerType) override def children: Seq[Expression] = str :: pos :: len :: Nil @@ -297,7 +297,7 @@ case class Substring(str: Expression, pos: Expression, len: Expression) */ case class StringLength(child: Expression) extends UnaryExpression with AutoCastInputTypes { override def dataType: DataType = IntegerType - override def expectedChildTypes: Seq[DataType] = Seq(StringType) + override def inputTypes: Seq[DataType] = Seq(StringType) override def eval(input: InternalRow): Any = { val string = child.eval(input) From 31b4a3d7f2be9053a041e5ae67418562a93d80d8 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 1 Jul 2015 10:31:35 -0700 Subject: [PATCH 07/18] [SPARK-8621] [SQL] support empty string as column name improve the empty check in `parseAttributeName` so that we can allow empty string as column name. Close https://github.com/apache/spark/pull/7117 Author: Wenchen Fan Closes #7149 from cloud-fan/8621 and squashes the following commits: efa9e3e [Wenchen Fan] support empty string --- .../spark/sql/catalyst/plans/logical/LogicalPlan.scala | 4 ++-- .../test/scala/org/apache/spark/sql/DataFrameSuite.scala | 7 +++++++ 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index b009a200b920f..e911b907e8536 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -161,7 +161,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { if (tmp.nonEmpty) throw e inBacktick = true } else if (char == '.') { - if (tmp.isEmpty) throw e + if (name(i - 1) == '.' || i == name.length - 1) throw e nameParts += tmp.mkString tmp.clear() } else { @@ -170,7 +170,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { } i += 1 } - if (tmp.isEmpty || inBacktick) throw e + if (inBacktick) throw e nameParts += tmp.mkString nameParts.toSeq } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 50d324c0686fa..afb1cf5f8d1cb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -730,4 +730,11 @@ class DataFrameSuite extends QueryTest { val res11 = ctx.range(-1).select("id") assert(res11.count == 0) } + + test("SPARK-8621: support empty string column name") { + val df = Seq(Tuple1(1)).toDF("").as("t") + // We should allow empty string as column name + df.col("") + df.col("t.``") + } } From 184de91d15a4bfc5c014e8cf86211874bba4593f Mon Sep 17 00:00:00 2001 From: lewuathe Date: Wed, 1 Jul 2015 11:14:07 -0700 Subject: [PATCH 08/18] [SPARK-6263] [MLLIB] Python MLlib API missing items: Utils Implement missing API in pyspark. MLUtils * appendBias * loadVectors `kFold` is also missing however I am not sure `ClassTag` can be passed or restored through python. Author: lewuathe Closes #5707 from Lewuathe/SPARK-6263 and squashes the following commits: 16863ea [lewuathe] Merge master 3fc27e7 [lewuathe] Merge branch 'master' into SPARK-6263 6084e9c [lewuathe] Resolv conflict d2aa2a0 [lewuathe] Resolv conflict 9c329d8 [lewuathe] Fix efficiency 3a12a2d [lewuathe] Merge branch 'master' into SPARK-6263 1d4714b [lewuathe] Fix style b29e2bc [lewuathe] Remove scipy dependencies e32eb40 [lewuathe] Merge branch 'master' into SPARK-6263 25d3c9d [lewuathe] Remove unnecessary imports 7ec04db [lewuathe] Resolv conflict 1502d13 [lewuathe] Resolv conflict d6bd416 [lewuathe] Check existence of scipy.sparse 5d555b1 [lewuathe] Construct scipy.sparse matrix c345a44 [lewuathe] Merge branch 'master' into SPARK-6263 b8b5ef7 [lewuathe] Fix unnecessary sort method d254be7 [lewuathe] Merge branch 'master' into SPARK-6263 62a9c7e [lewuathe] Fix appendBias return type 454c73d [lewuathe] Merge branch 'master' into SPARK-6263 a353354 [lewuathe] Remove unnecessary appendBias implementation 44295c2 [lewuathe] Merge branch 'master' into SPARK-6263 64f72ad [lewuathe] Merge branch 'master' into SPARK-6263 c728046 [lewuathe] Fix style 2980569 [lewuathe] [SPARK-6263] Python MLlib API missing items: Utils --- .../mllib/api/python/PythonMLLibAPI.scala | 9 ++++ python/pyspark/mllib/tests.py | 43 +++++++++++++++++++ python/pyspark/mllib/util.py | 22 ++++++++++ 3 files changed, 74 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index a66a404d5c846..458fab48fef5a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -75,6 +75,15 @@ private[python] class PythonMLLibAPI extends Serializable { minPartitions: Int): JavaRDD[LabeledPoint] = MLUtils.loadLabeledPoints(jsc.sc, path, minPartitions) + /** + * Loads and serializes vectors saved with `RDD#saveAsTextFile`. + * @param jsc Java SparkContext + * @param path file or directory path in any Hadoop-supported file system URI + * @return serialized vectors in a RDD + */ + def loadVectors(jsc: JavaSparkContext, path: String): RDD[Vector] = + MLUtils.loadVectors(jsc.sc, path) + private def trainRegressionModel( learner: GeneralizedLinearAlgorithm[_ <: GeneralizedLinearModel], data: JavaRDD[LabeledPoint], diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index f0091d6faccce..49ce125de7e78 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -54,6 +54,7 @@ from pyspark.mllib.feature import IDF from pyspark.mllib.feature import StandardScaler, ElementwiseProduct from pyspark.mllib.util import LinearDataGenerator +from pyspark.mllib.util import MLUtils from pyspark.serializers import PickleSerializer from pyspark.streaming import StreamingContext from pyspark.sql import SQLContext @@ -1290,6 +1291,48 @@ def func(rdd): self.assertTrue(mean_absolute_errors[1] - mean_absolute_errors[-1] > 2) +class MLUtilsTests(MLlibTestCase): + def test_append_bias(self): + data = [2.0, 2.0, 2.0] + ret = MLUtils.appendBias(data) + self.assertEqual(ret[3], 1.0) + self.assertEqual(type(ret), DenseVector) + + def test_append_bias_with_vector(self): + data = Vectors.dense([2.0, 2.0, 2.0]) + ret = MLUtils.appendBias(data) + self.assertEqual(ret[3], 1.0) + self.assertEqual(type(ret), DenseVector) + + def test_append_bias_with_sp_vector(self): + data = Vectors.sparse(3, {0: 2.0, 2: 2.0}) + expected = Vectors.sparse(4, {0: 2.0, 2: 2.0, 3: 1.0}) + # Returned value must be SparseVector + ret = MLUtils.appendBias(data) + self.assertEqual(ret, expected) + self.assertEqual(type(ret), SparseVector) + + def test_load_vectors(self): + import shutil + data = [ + [1.0, 2.0, 3.0], + [1.0, 2.0, 3.0] + ] + temp_dir = tempfile.mkdtemp() + load_vectors_path = os.path.join(temp_dir, "test_load_vectors") + try: + self.sc.parallelize(data).saveAsTextFile(load_vectors_path) + ret_rdd = MLUtils.loadVectors(self.sc, load_vectors_path) + ret = ret_rdd.collect() + self.assertEqual(len(ret), 2) + self.assertEqual(ret[0], DenseVector([1.0, 2.0, 3.0])) + self.assertEqual(ret[1], DenseVector([1.0, 2.0, 3.0])) + except: + self.fail() + finally: + shutil.rmtree(load_vectors_path) + + if __name__ == "__main__": if not _have_scipy: print("NOTE: Skipping SciPy tests as it does not seem to be installed") diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py index 348238319e407..875d3b2d642c6 100644 --- a/python/pyspark/mllib/util.py +++ b/python/pyspark/mllib/util.py @@ -169,6 +169,28 @@ def loadLabeledPoints(sc, path, minPartitions=None): minPartitions = minPartitions or min(sc.defaultParallelism, 2) return callMLlibFunc("loadLabeledPoints", sc, path, minPartitions) + @staticmethod + def appendBias(data): + """ + Returns a new vector with `1.0` (bias) appended to + the end of the input vector. + """ + vec = _convert_to_vector(data) + if isinstance(vec, SparseVector): + newIndices = np.append(vec.indices, len(vec)) + newValues = np.append(vec.values, 1.0) + return SparseVector(len(vec) + 1, newIndices, newValues) + else: + return _convert_to_vector(np.append(vec.toArray(), 1.0)) + + @staticmethod + def loadVectors(sc, path): + """ + Loads vectors saved using `RDD[Vector].saveAsTextFile` + with the default number of partitions. + """ + return callMLlibFunc("loadVectors", sc, path) + class Saveable(object): """ From 2012913355993e6516e4c81dbc92e579977131da Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Wed, 1 Jul 2015 11:17:56 -0700 Subject: [PATCH 09/18] [SPARK-8308] [MLLIB] add missing save load for python example jira: https://issues.apache.org/jira/browse/SPARK-8308 1. add some missing save/load in python examples. , LogisticRegression, LinearRegression and NaiveBayes 2. tune down iterations for MatrixFactorization, since current number will trigger StackOverflow for default java configuration (>1M) Author: Yuhao Yang Closes #6760 from hhbyyh/docUpdate and squashes the following commits: 9bd3383 [Yuhao Yang] update scala example 8a44692 [Yuhao Yang] Merge remote-tracking branch 'upstream/master' into docUpdate 077cbb8 [Yuhao Yang] Merge remote-tracking branch 'upstream/master' into docUpdate 3e948dc [Yuhao Yang] add missing save load for python example --- docs/mllib-collaborative-filtering.md | 6 +++--- docs/mllib-linear-methods.md | 12 ++++++++++-- docs/mllib-naive-bayes.md | 6 +++++- 3 files changed, 18 insertions(+), 6 deletions(-) diff --git a/docs/mllib-collaborative-filtering.md b/docs/mllib-collaborative-filtering.md index dfdf6216b270c..eedc23424ad54 100644 --- a/docs/mllib-collaborative-filtering.md +++ b/docs/mllib-collaborative-filtering.md @@ -77,7 +77,7 @@ val ratings = data.map(_.split(',') match { case Array(user, item, rate) => // Build the recommendation model using ALS val rank = 10 -val numIterations = 20 +val numIterations = 10 val model = ALS.train(ratings, rank, numIterations, 0.01) // Evaluate the model on rating data @@ -149,7 +149,7 @@ public class CollaborativeFiltering { // Build the recommendation model using ALS int rank = 10; - int numIterations = 20; + int numIterations = 10; MatrixFactorizationModel model = ALS.train(JavaRDD.toRDD(ratings), rank, numIterations, 0.01); // Evaluate the model on rating data @@ -210,7 +210,7 @@ ratings = data.map(lambda l: l.split(',')).map(lambda l: Rating(int(l[0]), int(l # Build the recommendation model using Alternating Least Squares rank = 10 -numIterations = 20 +numIterations = 10 model = ALS.train(ratings, rank, numIterations) # Evaluate the model on training data diff --git a/docs/mllib-linear-methods.md b/docs/mllib-linear-methods.md index 2a2a7c13186d8..3927d65fbf8fb 100644 --- a/docs/mllib-linear-methods.md +++ b/docs/mllib-linear-methods.md @@ -499,7 +499,7 @@ Note that the Python API does not yet support multiclass classification and mode will in the future. {% highlight python %} -from pyspark.mllib.classification import LogisticRegressionWithLBFGS +from pyspark.mllib.classification import LogisticRegressionWithLBFGS, LogisticRegressionModel from pyspark.mllib.regression import LabeledPoint from numpy import array @@ -518,6 +518,10 @@ model = LogisticRegressionWithLBFGS.train(parsedData) labelsAndPreds = parsedData.map(lambda p: (p.label, model.predict(p.features))) trainErr = labelsAndPreds.filter(lambda (v, p): v != p).count() / float(parsedData.count()) print("Training Error = " + str(trainErr)) + +# Save and load model +model.save(sc, "myModelPath") +sameModel = LogisticRegressionModel.load(sc, "myModelPath") {% endhighlight %} @@ -668,7 +672,7 @@ values. We compute the mean squared error at the end to evaluate Note that the Python API does not yet support model save/load but will in the future. {% highlight python %} -from pyspark.mllib.regression import LabeledPoint, LinearRegressionWithSGD +from pyspark.mllib.regression import LabeledPoint, LinearRegressionWithSGD, LinearRegressionModel from numpy import array # Load and parse the data @@ -686,6 +690,10 @@ model = LinearRegressionWithSGD.train(parsedData) valuesAndPreds = parsedData.map(lambda p: (p.label, model.predict(p.features))) MSE = valuesAndPreds.map(lambda (v, p): (v - p)**2).reduce(lambda x, y: x + y) / valuesAndPreds.count() print("Mean Squared Error = " + str(MSE)) + +# Save and load model +model.save(sc, "myModelPath") +sameModel = LinearRegressionModel.load(sc, "myModelPath") {% endhighlight %} diff --git a/docs/mllib-naive-bayes.md b/docs/mllib-naive-bayes.md index bf6d124fd5d8d..e73bd30f3a90a 100644 --- a/docs/mllib-naive-bayes.md +++ b/docs/mllib-naive-bayes.md @@ -119,7 +119,7 @@ used for evaluation and prediction. Note that the Python API does not yet support model save/load but will in the future. {% highlight python %} -from pyspark.mllib.classification import NaiveBayes +from pyspark.mllib.classification import NaiveBayes, NaiveBayesModel from pyspark.mllib.linalg import Vectors from pyspark.mllib.regression import LabeledPoint @@ -140,6 +140,10 @@ model = NaiveBayes.train(training, 1.0) # Make prediction and test accuracy. predictionAndLabel = test.map(lambda p : (model.predict(p.features), p.label)) accuracy = 1.0 * predictionAndLabel.filter(lambda (x, v): x == v).count() / test.count() + +# Save and load model +model.save(sc, "myModelPath") +sameModel = NaiveBayesModel.load(sc, "myModelPath") {% endhighlight %} From b8faa32875aa560cdce340266d898902a920418d Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Wed, 1 Jul 2015 11:57:52 -0700 Subject: [PATCH 10/18] [SPARK-8765] [MLLIB] [PYTHON] removed flaky python PIC test See failure: [https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/36133/console] CC yanboliang mengxr Author: Joseph K. Bradley Closes #7164 from jkbradley/pic-python-test and squashes the following commits: 156d55b [Joseph K. Bradley] removed flaky python PIC test --- python/pyspark/mllib/clustering.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index e3c8a24c4a751..a3eab635282f6 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -288,16 +288,12 @@ class PowerIterationClusteringModel(JavaModelWrapper, JavaSaveable, JavaLoader): >>> model = PowerIterationClustering.train(rdd, 2, 100) >>> model.k 2 - >>> sorted(model.assignments().collect()) - [Assignment(id=0, cluster=1), Assignment(id=1, cluster=0), ... >>> import os, tempfile >>> path = tempfile.mkdtemp() >>> model.save(sc, path) >>> sameModel = PowerIterationClusteringModel.load(sc, path) >>> sameModel.k 2 - >>> sorted(sameModel.assignments().collect()) - [Assignment(id=0, cluster=1), Assignment(id=1, cluster=0), ... >>> from shutil import rmtree >>> try: ... rmtree(path) From 75b9fe4c5ff6f206c6fc9100563d625b39f142ba Mon Sep 17 00:00:00 2001 From: zsxwing Date: Wed, 1 Jul 2015 11:59:24 -0700 Subject: [PATCH 11/18] [SPARK-8378] [STREAMING] Add the Python API for Flume Author: zsxwing Closes #6830 from zsxwing/flume-python and squashes the following commits: 78dfdac [zsxwing] Fix the compile error in the test code f1bf3c0 [zsxwing] Address TD's comments 0449723 [zsxwing] Add sbt goal streaming-flume-assembly/assembly e93736b [zsxwing] Fix the test case for determine_modules_to_test 9d5821e [zsxwing] Fix pyspark_core dependencies f9ee681 [zsxwing] Merge branch 'master' into flume-python 7a55837 [zsxwing] Add streaming_flume_assembly to run-tests.py b96b0de [zsxwing] Merge branch 'master' into flume-python ce85e83 [zsxwing] Fix incompatible issues for Python 3 01cbb3d [zsxwing] Add import sys 152364c [zsxwing] Fix the issue that StringIO doesn't work in Python 3 14ba0ff [zsxwing] Add flume-assembly for sbt building b8d5551 [zsxwing] Merge branch 'master' into flume-python 4762c34 [zsxwing] Fix the doc 0336579 [zsxwing] Refactor Flume unit tests and also add tests for Python API 9f33873 [zsxwing] Add the Python API for Flume --- dev/run-tests.py | 7 +- dev/sparktestsupport/modules.py | 15 +- docs/streaming-flume-integration.md | 18 ++ docs/streaming-programming-guide.md | 2 +- .../main/python/streaming/flume_wordcount.py | 55 +++++ external/flume-assembly/pom.xml | 135 +++++++++++ .../streaming/flume/FlumeTestUtils.scala | 116 ++++++++++ .../spark/streaming/flume/FlumeUtils.scala | 76 ++++++- .../flume/PollingFlumeTestUtils.scala | 209 ++++++++++++++++++ .../flume/FlumePollingStreamSuite.scala | 173 +++------------ .../streaming/flume/FlumeStreamSuite.scala | 106 ++------- pom.xml | 1 + project/SparkBuild.scala | 6 +- python/pyspark/streaming/flume.py | 147 ++++++++++++ python/pyspark/streaming/tests.py | 179 ++++++++++++++- 15 files changed, 1009 insertions(+), 236 deletions(-) create mode 100644 examples/src/main/python/streaming/flume_wordcount.py create mode 100644 external/flume-assembly/pom.xml create mode 100644 external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeTestUtils.scala create mode 100644 external/flume/src/main/scala/org/apache/spark/streaming/flume/PollingFlumeTestUtils.scala create mode 100644 python/pyspark/streaming/flume.py diff --git a/dev/run-tests.py b/dev/run-tests.py index 4596e07014733..1f0d218514f92 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -96,8 +96,8 @@ def determine_modules_to_test(changed_modules): ['examples', 'graphx'] >>> x = sorted(x.name for x in determine_modules_to_test([modules.sql])) >>> x # doctest: +NORMALIZE_WHITESPACE - ['examples', 'hive-thriftserver', 'mllib', 'pyspark-core', 'pyspark-ml', \ - 'pyspark-mllib', 'pyspark-sql', 'pyspark-streaming', 'sparkr', 'sql'] + ['examples', 'hive-thriftserver', 'mllib', 'pyspark-ml', \ + 'pyspark-mllib', 'pyspark-sql', 'sparkr', 'sql'] """ # If we're going to have to run all of the tests, then we can just short-circuit # and return 'root'. No module depends on root, so if it appears then it will be @@ -293,7 +293,8 @@ def build_spark_sbt(hadoop_version): build_profiles = get_hadoop_profiles(hadoop_version) + modules.root.build_profile_flags sbt_goals = ["package", "assembly/assembly", - "streaming-kafka-assembly/assembly"] + "streaming-kafka-assembly/assembly", + "streaming-flume-assembly/assembly"] profiles_and_goals = build_profiles + sbt_goals print("[info] Building Spark (w/Hive 0.13.1) using SBT with these arguments: ", diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index efe3a897e9c10..993583e2f4119 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -203,7 +203,7 @@ def contains_file(self, filename): streaming_flume = Module( - name="streaming_flume", + name="streaming-flume", dependencies=[streaming], source_file_regexes=[ "external/flume", @@ -214,6 +214,15 @@ def contains_file(self, filename): ) +streaming_flume_assembly = Module( + name="streaming-flume-assembly", + dependencies=[streaming_flume, streaming_flume_sink], + source_file_regexes=[ + "external/flume-assembly", + ] +) + + mllib = Module( name="mllib", dependencies=[streaming, sql], @@ -241,7 +250,7 @@ def contains_file(self, filename): pyspark_core = Module( name="pyspark-core", - dependencies=[mllib, streaming, streaming_kafka], + dependencies=[], source_file_regexes=[ "python/(?!pyspark/(ml|mllib|sql|streaming))" ], @@ -281,7 +290,7 @@ def contains_file(self, filename): pyspark_streaming = Module( name="pyspark-streaming", - dependencies=[pyspark_core, streaming, streaming_kafka], + dependencies=[pyspark_core, streaming, streaming_kafka, streaming_flume_assembly], source_file_regexes=[ "python/pyspark/streaming" ], diff --git a/docs/streaming-flume-integration.md b/docs/streaming-flume-integration.md index 8d6e74370918f..de0461010daec 100644 --- a/docs/streaming-flume-integration.md +++ b/docs/streaming-flume-integration.md @@ -58,6 +58,15 @@ configuring Flume agents. See the [API docs](api/java/index.html?org/apache/spark/streaming/flume/FlumeUtils.html) and the [example]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/java/org/apache/spark/examples/streaming/JavaFlumeEventCount.java). +
+ from pyspark.streaming.flume import FlumeUtils + + flumeStream = FlumeUtils.createStream(streamingContext, [chosen machine's hostname], [chosen port]) + + By default, the Python API will decode Flume event body as UTF8 encoded strings. You can specify your custom decoding function to decode the body byte arrays in Flume events to any arbitrary data type. + See the [API docs](api/python/pyspark.streaming.html#pyspark.streaming.flume.FlumeUtils) + and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/python/streaming/flume_wordcount.py). +
Note that the hostname should be the same as the one used by the resource manager in the @@ -135,6 +144,15 @@ configuring Flume agents. JavaReceiverInputDStreamflumeStream = FlumeUtils.createPollingStream(streamingContext, [sink machine hostname], [sink port]); +
+ from pyspark.streaming.flume import FlumeUtils + + addresses = [([sink machine hostname 1], [sink port 1]), ([sink machine hostname 2], [sink port 2])] + flumeStream = FlumeUtils.createPollingStream(streamingContext, addresses) + + By default, the Python API will decode Flume event body as UTF8 encoded strings. You can specify your custom decoding function to decode the body byte arrays in Flume events to any arbitrary data type. + See the [API docs](api/python/pyspark.streaming.html#pyspark.streaming.flume.FlumeUtils). +
See the Scala example [FlumePollingEventCount]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/scala/org/apache/spark/examples/streaming/FlumePollingEventCount.scala). diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index b784d59666fec..e72d5580dae55 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -683,7 +683,7 @@ for Java, and [StreamingContext](api/python/pyspark.streaming.html#pyspark.strea {:.no_toc} Python API As of Spark {{site.SPARK_VERSION_SHORT}}, -out of these sources, *only* Kafka is available in the Python API. We will add more advanced sources in the Python API in future. +out of these sources, *only* Kafka and Flume are available in the Python API. We will add more advanced sources in the Python API in future. This category of sources require interfacing with external non-Spark libraries, some of them with complex dependencies (e.g., Kafka and Flume). Hence, to minimize issues related to version conflicts diff --git a/examples/src/main/python/streaming/flume_wordcount.py b/examples/src/main/python/streaming/flume_wordcount.py new file mode 100644 index 0000000000000..091b64d8c4af4 --- /dev/null +++ b/examples/src/main/python/streaming/flume_wordcount.py @@ -0,0 +1,55 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" + Counts words in UTF8 encoded, '\n' delimited text received from the network every second. + Usage: flume_wordcount.py + + To run this on your local machine, you need to setup Flume first, see + https://flume.apache.org/documentation.html + + and then run the example + `$ bin/spark-submit --jars external/flume-assembly/target/scala-*/\ + spark-streaming-flume-assembly-*.jar examples/src/main/python/streaming/flume_wordcount.py \ + localhost 12345 +""" +from __future__ import print_function + +import sys + +from pyspark import SparkContext +from pyspark.streaming import StreamingContext +from pyspark.streaming.flume import FlumeUtils + +if __name__ == "__main__": + if len(sys.argv) != 3: + print("Usage: flume_wordcount.py ", file=sys.stderr) + exit(-1) + + sc = SparkContext(appName="PythonStreamingFlumeWordCount") + ssc = StreamingContext(sc, 1) + + hostname, port = sys.argv[1:] + kvs = FlumeUtils.createStream(ssc, hostname, int(port)) + lines = kvs.map(lambda x: x[1]) + counts = lines.flatMap(lambda line: line.split(" ")) \ + .map(lambda word: (word, 1)) \ + .reduceByKey(lambda a, b: a+b) + counts.pprint() + + ssc.start() + ssc.awaitTermination() diff --git a/external/flume-assembly/pom.xml b/external/flume-assembly/pom.xml new file mode 100644 index 0000000000000..8565cd83edfa2 --- /dev/null +++ b/external/flume-assembly/pom.xml @@ -0,0 +1,135 @@ + + + + + 4.0.0 + + org.apache.spark + spark-parent_2.10 + 1.5.0-SNAPSHOT + ../../pom.xml + + + org.apache.spark + spark-streaming-flume-assembly_2.10 + jar + Spark Project External Flume Assembly + http://spark.apache.org/ + + + streaming-flume-assembly + + + + + org.apache.spark + spark-streaming-flume_${scala.binary.version} + ${project.version} + + + org.apache.spark + spark-streaming_${scala.binary.version} + ${project.version} + provided + + + org.apache.avro + avro + ${avro.version} + + + org.apache.avro + avro-ipc + ${avro.version} + + + io.netty + netty + + + org.mortbay.jetty + jetty + + + org.mortbay.jetty + jetty-util + + + org.mortbay.jetty + servlet-api + + + org.apache.velocity + velocity + + + + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + + org.apache.maven.plugins + maven-shade-plugin + + false + ${project.build.directory}/scala-${scala.binary.version}/spark-streaming-flume-assembly-${project.version}.jar + + + *:* + + + + + *:* + + META-INF/*.SF + META-INF/*.DSA + META-INF/*.RSA + + + + + + + package + + shade + + + + + + reference.conf + + + log4j.properties + + + + + + + + + + + + diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeTestUtils.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeTestUtils.scala new file mode 100644 index 0000000000000..9d9c3b189415f --- /dev/null +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeTestUtils.scala @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.flume + +import java.net.{InetSocketAddress, ServerSocket} +import java.nio.ByteBuffer +import java.util.{List => JList} + +import scala.collection.JavaConversions._ + +import com.google.common.base.Charsets.UTF_8 +import org.apache.avro.ipc.NettyTransceiver +import org.apache.avro.ipc.specific.SpecificRequestor +import org.apache.commons.lang3.RandomUtils +import org.apache.flume.source.avro +import org.apache.flume.source.avro.{AvroSourceProtocol, AvroFlumeEvent} +import org.jboss.netty.channel.ChannelPipeline +import org.jboss.netty.channel.socket.SocketChannel +import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory +import org.jboss.netty.handler.codec.compression.{ZlibDecoder, ZlibEncoder} + +import org.apache.spark.util.Utils +import org.apache.spark.SparkConf + +/** + * Share codes for Scala and Python unit tests + */ +private[flume] class FlumeTestUtils { + + private var transceiver: NettyTransceiver = null + + private val testPort: Int = findFreePort() + + def getTestPort(): Int = testPort + + /** Find a free port */ + private def findFreePort(): Int = { + val candidatePort = RandomUtils.nextInt(1024, 65536) + Utils.startServiceOnPort(candidatePort, (trialPort: Int) => { + val socket = new ServerSocket(trialPort) + socket.close() + (null, trialPort) + }, new SparkConf())._2 + } + + /** Send data to the flume receiver */ + def writeInput(input: JList[String], enableCompression: Boolean): Unit = { + val testAddress = new InetSocketAddress("localhost", testPort) + + val inputEvents = input.map { item => + val event = new AvroFlumeEvent + event.setBody(ByteBuffer.wrap(item.getBytes(UTF_8))) + event.setHeaders(Map[CharSequence, CharSequence]("test" -> "header")) + event + } + + // if last attempted transceiver had succeeded, close it + close() + + // Create transceiver + transceiver = { + if (enableCompression) { + new NettyTransceiver(testAddress, new CompressionChannelFactory(6)) + } else { + new NettyTransceiver(testAddress) + } + } + + // Create Avro client with the transceiver + val client = SpecificRequestor.getClient(classOf[AvroSourceProtocol], transceiver) + if (client == null) { + throw new AssertionError("Cannot create client") + } + + // Send data + val status = client.appendBatch(inputEvents.toList) + if (status != avro.Status.OK) { + throw new AssertionError("Sent events unsuccessfully") + } + } + + def close(): Unit = { + if (transceiver != null) { + transceiver.close() + transceiver = null + } + } + + /** Class to create socket channel with compression */ + private class CompressionChannelFactory(compressionLevel: Int) + extends NioClientSocketChannelFactory { + + override def newChannel(pipeline: ChannelPipeline): SocketChannel = { + val encoder = new ZlibEncoder(compressionLevel) + pipeline.addFirst("deflater", encoder) + pipeline.addFirst("inflater", new ZlibDecoder()) + super.newChannel(pipeline) + } + } + +} diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala index 44dec45c227ca..095bfb0c73a9a 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala @@ -18,10 +18,16 @@ package org.apache.spark.streaming.flume import java.net.InetSocketAddress +import java.io.{DataOutputStream, ByteArrayOutputStream} +import java.util.{List => JList, Map => JMap} +import scala.collection.JavaConversions._ + +import org.apache.spark.api.java.function.PairFunction +import org.apache.spark.api.python.PythonRDD import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.StreamingContext -import org.apache.spark.streaming.api.java.{JavaReceiverInputDStream, JavaStreamingContext} +import org.apache.spark.streaming.api.java.{JavaPairDStream, JavaReceiverInputDStream, JavaStreamingContext} import org.apache.spark.streaming.dstream.ReceiverInputDStream @@ -236,3 +242,71 @@ object FlumeUtils { createPollingStream(jssc.ssc, addresses, storageLevel, maxBatchSize, parallelism) } } + +/** + * This is a helper class that wraps the methods in FlumeUtils into more Python-friendly class and + * function so that it can be easily instantiated and called from Python's FlumeUtils. + */ +private class FlumeUtilsPythonHelper { + + def createStream( + jssc: JavaStreamingContext, + hostname: String, + port: Int, + storageLevel: StorageLevel, + enableDecompression: Boolean + ): JavaPairDStream[Array[Byte], Array[Byte]] = { + val dstream = FlumeUtils.createStream(jssc, hostname, port, storageLevel, enableDecompression) + FlumeUtilsPythonHelper.toByteArrayPairDStream(dstream) + } + + def createPollingStream( + jssc: JavaStreamingContext, + hosts: JList[String], + ports: JList[Int], + storageLevel: StorageLevel, + maxBatchSize: Int, + parallelism: Int + ): JavaPairDStream[Array[Byte], Array[Byte]] = { + assert(hosts.length == ports.length) + val addresses = hosts.zip(ports).map { + case (host, port) => new InetSocketAddress(host, port) + } + val dstream = FlumeUtils.createPollingStream( + jssc.ssc, addresses, storageLevel, maxBatchSize, parallelism) + FlumeUtilsPythonHelper.toByteArrayPairDStream(dstream) + } + +} + +private object FlumeUtilsPythonHelper { + + private def stringMapToByteArray(map: JMap[CharSequence, CharSequence]): Array[Byte] = { + val byteStream = new ByteArrayOutputStream() + val output = new DataOutputStream(byteStream) + try { + output.writeInt(map.size) + map.foreach { kv => + PythonRDD.writeUTF(kv._1.toString, output) + PythonRDD.writeUTF(kv._2.toString, output) + } + byteStream.toByteArray + } + finally { + output.close() + } + } + + private def toByteArrayPairDStream(dstream: JavaReceiverInputDStream[SparkFlumeEvent]): + JavaPairDStream[Array[Byte], Array[Byte]] = { + dstream.mapToPair(new PairFunction[SparkFlumeEvent, Array[Byte], Array[Byte]] { + override def call(sparkEvent: SparkFlumeEvent): (Array[Byte], Array[Byte]) = { + val event = sparkEvent.event + val byteBuffer = event.getBody + val body = new Array[Byte](byteBuffer.remaining()) + byteBuffer.get(body) + (stringMapToByteArray(event.getHeaders), body) + } + }) + } +} diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/PollingFlumeTestUtils.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/PollingFlumeTestUtils.scala new file mode 100644 index 0000000000000..91d63d49dbec3 --- /dev/null +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/PollingFlumeTestUtils.scala @@ -0,0 +1,209 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.flume + +import java.util.concurrent._ +import java.util.{List => JList, Map => JMap} + +import scala.collection.JavaConversions._ +import scala.collection.mutable.ArrayBuffer + +import com.google.common.base.Charsets.UTF_8 +import org.apache.flume.event.EventBuilder +import org.apache.flume.Context +import org.apache.flume.channel.MemoryChannel +import org.apache.flume.conf.Configurables + +import org.apache.spark.streaming.flume.sink.{SparkSinkConfig, SparkSink} + +/** + * Share codes for Scala and Python unit tests + */ +private[flume] class PollingFlumeTestUtils { + + private val batchCount = 5 + val eventsPerBatch = 100 + private val totalEventsPerChannel = batchCount * eventsPerBatch + private val channelCapacity = 5000 + + def getTotalEvents: Int = totalEventsPerChannel * channels.size + + private val channels = new ArrayBuffer[MemoryChannel] + private val sinks = new ArrayBuffer[SparkSink] + + /** + * Start a sink and return the port of this sink + */ + def startSingleSink(): Int = { + channels.clear() + sinks.clear() + + // Start the channel and sink. + val context = new Context() + context.put("capacity", channelCapacity.toString) + context.put("transactionCapacity", "1000") + context.put("keep-alive", "0") + val channel = new MemoryChannel() + Configurables.configure(channel, context) + + val sink = new SparkSink() + context.put(SparkSinkConfig.CONF_HOSTNAME, "localhost") + context.put(SparkSinkConfig.CONF_PORT, String.valueOf(0)) + Configurables.configure(sink, context) + sink.setChannel(channel) + sink.start() + + channels += (channel) + sinks += sink + + sink.getPort() + } + + /** + * Start 2 sinks and return the ports + */ + def startMultipleSinks(): JList[Int] = { + channels.clear() + sinks.clear() + + // Start the channel and sink. + val context = new Context() + context.put("capacity", channelCapacity.toString) + context.put("transactionCapacity", "1000") + context.put("keep-alive", "0") + val channel = new MemoryChannel() + Configurables.configure(channel, context) + + val channel2 = new MemoryChannel() + Configurables.configure(channel2, context) + + val sink = new SparkSink() + context.put(SparkSinkConfig.CONF_HOSTNAME, "localhost") + context.put(SparkSinkConfig.CONF_PORT, String.valueOf(0)) + Configurables.configure(sink, context) + sink.setChannel(channel) + sink.start() + + val sink2 = new SparkSink() + context.put(SparkSinkConfig.CONF_HOSTNAME, "localhost") + context.put(SparkSinkConfig.CONF_PORT, String.valueOf(0)) + Configurables.configure(sink2, context) + sink2.setChannel(channel2) + sink2.start() + + sinks += sink + sinks += sink2 + channels += channel + channels += channel2 + + sinks.map(_.getPort()) + } + + /** + * Send data and wait until all data has been received + */ + def sendDatAndEnsureAllDataHasBeenReceived(): Unit = { + val executor = Executors.newCachedThreadPool() + val executorCompletion = new ExecutorCompletionService[Void](executor) + + val latch = new CountDownLatch(batchCount * channels.size) + sinks.foreach(_.countdownWhenBatchReceived(latch)) + + channels.foreach(channel => { + executorCompletion.submit(new TxnSubmitter(channel)) + }) + + for (i <- 0 until channels.size) { + executorCompletion.take() + } + + latch.await(15, TimeUnit.SECONDS) // Ensure all data has been received. + } + + /** + * A Python-friendly method to assert the output + */ + def assertOutput( + outputHeaders: JList[JMap[String, String]], outputBodies: JList[String]): Unit = { + require(outputHeaders.size == outputBodies.size) + val eventSize = outputHeaders.size + if (eventSize != totalEventsPerChannel * channels.size) { + throw new AssertionError( + s"Expected ${totalEventsPerChannel * channels.size} events, but was $eventSize") + } + var counter = 0 + for (k <- 0 until channels.size; i <- 0 until totalEventsPerChannel) { + val eventBodyToVerify = s"${channels(k).getName}-$i" + val eventHeaderToVerify: JMap[String, String] = Map[String, String](s"test-$i" -> "header") + var found = false + var j = 0 + while (j < eventSize && !found) { + if (eventBodyToVerify == outputBodies.get(j) && + eventHeaderToVerify == outputHeaders.get(j)) { + found = true + counter += 1 + } + j += 1 + } + } + if (counter != totalEventsPerChannel * channels.size) { + throw new AssertionError( + s"111 Expected ${totalEventsPerChannel * channels.size} events, but was $counter") + } + } + + def assertChannelsAreEmpty(): Unit = { + channels.foreach(assertChannelIsEmpty) + } + + private def assertChannelIsEmpty(channel: MemoryChannel): Unit = { + val queueRemaining = channel.getClass.getDeclaredField("queueRemaining") + queueRemaining.setAccessible(true) + val m = queueRemaining.get(channel).getClass.getDeclaredMethod("availablePermits") + if (m.invoke(queueRemaining.get(channel)).asInstanceOf[Int] != 5000) { + throw new AssertionError(s"Channel ${channel.getName} is not empty") + } + } + + def close(): Unit = { + sinks.foreach(_.stop()) + sinks.clear() + channels.foreach(_.stop()) + channels.clear() + } + + private class TxnSubmitter(channel: MemoryChannel) extends Callable[Void] { + override def call(): Void = { + var t = 0 + for (i <- 0 until batchCount) { + val tx = channel.getTransaction + tx.begin() + for (j <- 0 until eventsPerBatch) { + channel.put(EventBuilder.withBody(s"${channel.getName}-$t".getBytes(UTF_8), + Map[String, String](s"test-$t" -> "header"))) + t += 1 + } + tx.commit() + tx.close() + Thread.sleep(500) // Allow some time for the events to reach + } + null + } + } + +} diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala index d772b9ca9b570..d5f9a0aa38f9f 100644 --- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala +++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala @@ -18,47 +18,33 @@ package org.apache.spark.streaming.flume import java.net.InetSocketAddress -import java.util.concurrent._ import scala.collection.JavaConversions._ import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer} import scala.concurrent.duration._ import scala.language.postfixOps -import org.apache.flume.Context -import org.apache.flume.channel.MemoryChannel -import org.apache.flume.conf.Configurables -import org.apache.flume.event.EventBuilder -import org.scalatest.concurrent.Eventually._ - +import com.google.common.base.Charsets.UTF_8 import org.scalatest.BeforeAndAfter +import org.scalatest.concurrent.Eventually._ import org.apache.spark.{Logging, SparkConf, SparkFunSuite} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.dstream.ReceiverInputDStream import org.apache.spark.streaming.{Seconds, TestOutputStream, StreamingContext} -import org.apache.spark.streaming.flume.sink._ import org.apache.spark.util.{ManualClock, Utils} class FlumePollingStreamSuite extends SparkFunSuite with BeforeAndAfter with Logging { - val batchCount = 5 - val eventsPerBatch = 100 - val totalEventsPerChannel = batchCount * eventsPerBatch - val channelCapacity = 5000 val maxAttempts = 5 val batchDuration = Seconds(1) val conf = new SparkConf() .setMaster("local[2]") .setAppName(this.getClass.getSimpleName) + .set("spark.streaming.clock", "org.apache.spark.util.ManualClock") - def beforeFunction() { - logInfo("Using manual clock") - conf.set("spark.streaming.clock", "org.apache.spark.util.ManualClock") - } - - before(beforeFunction()) + val utils = new PollingFlumeTestUtils test("flume polling test") { testMultipleTimes(testFlumePolling) @@ -89,146 +75,55 @@ class FlumePollingStreamSuite extends SparkFunSuite with BeforeAndAfter with Log } private def testFlumePolling(): Unit = { - // Start the channel and sink. - val context = new Context() - context.put("capacity", channelCapacity.toString) - context.put("transactionCapacity", "1000") - context.put("keep-alive", "0") - val channel = new MemoryChannel() - Configurables.configure(channel, context) - - val sink = new SparkSink() - context.put(SparkSinkConfig.CONF_HOSTNAME, "localhost") - context.put(SparkSinkConfig.CONF_PORT, String.valueOf(0)) - Configurables.configure(sink, context) - sink.setChannel(channel) - sink.start() - - writeAndVerify(Seq(sink), Seq(channel)) - assertChannelIsEmpty(channel) - sink.stop() - channel.stop() + try { + val port = utils.startSingleSink() + + writeAndVerify(Seq(port)) + utils.assertChannelsAreEmpty() + } finally { + utils.close() + } } private def testFlumePollingMultipleHost(): Unit = { - // Start the channel and sink. - val context = new Context() - context.put("capacity", channelCapacity.toString) - context.put("transactionCapacity", "1000") - context.put("keep-alive", "0") - val channel = new MemoryChannel() - Configurables.configure(channel, context) - - val channel2 = new MemoryChannel() - Configurables.configure(channel2, context) - - val sink = new SparkSink() - context.put(SparkSinkConfig.CONF_HOSTNAME, "localhost") - context.put(SparkSinkConfig.CONF_PORT, String.valueOf(0)) - Configurables.configure(sink, context) - sink.setChannel(channel) - sink.start() - - val sink2 = new SparkSink() - context.put(SparkSinkConfig.CONF_HOSTNAME, "localhost") - context.put(SparkSinkConfig.CONF_PORT, String.valueOf(0)) - Configurables.configure(sink2, context) - sink2.setChannel(channel2) - sink2.start() try { - writeAndVerify(Seq(sink, sink2), Seq(channel, channel2)) - assertChannelIsEmpty(channel) - assertChannelIsEmpty(channel2) + val ports = utils.startMultipleSinks() + writeAndVerify(ports) + utils.assertChannelsAreEmpty() } finally { - sink.stop() - sink2.stop() - channel.stop() - channel2.stop() + utils.close() } } - def writeAndVerify(sinks: Seq[SparkSink], channels: Seq[MemoryChannel]) { + def writeAndVerify(sinkPorts: Seq[Int]): Unit = { // Set up the streaming context and input streams val ssc = new StreamingContext(conf, batchDuration) - val addresses = sinks.map(sink => new InetSocketAddress("localhost", sink.getPort())) + val addresses = sinkPorts.map(port => new InetSocketAddress("localhost", port)) val flumeStream: ReceiverInputDStream[SparkFlumeEvent] = FlumeUtils.createPollingStream(ssc, addresses, StorageLevel.MEMORY_AND_DISK, - eventsPerBatch, 5) + utils.eventsPerBatch, 5) val outputBuffer = new ArrayBuffer[Seq[SparkFlumeEvent]] with SynchronizedBuffer[Seq[SparkFlumeEvent]] val outputStream = new TestOutputStream(flumeStream, outputBuffer) outputStream.register() ssc.start() - val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] - val executor = Executors.newCachedThreadPool() - val executorCompletion = new ExecutorCompletionService[Void](executor) - - val latch = new CountDownLatch(batchCount * channels.size) - sinks.foreach(_.countdownWhenBatchReceived(latch)) - - channels.foreach(channel => { - executorCompletion.submit(new TxnSubmitter(channel, clock)) - }) - - for (i <- 0 until channels.size) { - executorCompletion.take() - } - - latch.await(15, TimeUnit.SECONDS) // Ensure all data has been received. - clock.advance(batchDuration.milliseconds) - - // The eventually is required to ensure that all data in the batch has been processed. - eventually(timeout(10 seconds), interval(100 milliseconds)) { - val flattenedBuffer = outputBuffer.flatten - assert(flattenedBuffer.size === totalEventsPerChannel * channels.size) - var counter = 0 - for (k <- 0 until channels.size; i <- 0 until totalEventsPerChannel) { - val eventToVerify = EventBuilder.withBody((channels(k).getName + " - " + - String.valueOf(i)).getBytes("utf-8"), - Map[String, String]("test-" + i.toString -> "header")) - var found = false - var j = 0 - while (j < flattenedBuffer.size && !found) { - val strToCompare = new String(flattenedBuffer(j).event.getBody.array(), "utf-8") - if (new String(eventToVerify.getBody, "utf-8") == strToCompare && - eventToVerify.getHeaders.get("test-" + i.toString) - .equals(flattenedBuffer(j).event.getHeaders.get("test-" + i.toString))) { - found = true - counter += 1 - } - j += 1 - } - } - assert(counter === totalEventsPerChannel * channels.size) - } - ssc.stop() - } - - def assertChannelIsEmpty(channel: MemoryChannel): Unit = { - val queueRemaining = channel.getClass.getDeclaredField("queueRemaining") - queueRemaining.setAccessible(true) - val m = queueRemaining.get(channel).getClass.getDeclaredMethod("availablePermits") - assert(m.invoke(queueRemaining.get(channel)).asInstanceOf[Int] === 5000) - } - - private class TxnSubmitter(channel: MemoryChannel, clock: ManualClock) extends Callable[Void] { - override def call(): Void = { - var t = 0 - for (i <- 0 until batchCount) { - val tx = channel.getTransaction - tx.begin() - for (j <- 0 until eventsPerBatch) { - channel.put(EventBuilder.withBody((channel.getName + " - " + String.valueOf(t)).getBytes( - "utf-8"), - Map[String, String]("test-" + t.toString -> "header"))) - t += 1 - } - tx.commit() - tx.close() - Thread.sleep(500) // Allow some time for the events to reach + try { + utils.sendDatAndEnsureAllDataHasBeenReceived() + val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + clock.advance(batchDuration.milliseconds) + + // The eventually is required to ensure that all data in the batch has been processed. + eventually(timeout(10 seconds), interval(100 milliseconds)) { + val flattenOutputBuffer = outputBuffer.flatten + val headers = flattenOutputBuffer.map(_.event.getHeaders.map { + case kv => (kv._1.toString, kv._2.toString) + }).map(mapAsJavaMap) + val bodies = flattenOutputBuffer.map(e => new String(e.event.getBody.array(), UTF_8)) + utils.assertOutput(headers, bodies) } - null + } finally { + ssc.stop() } } diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala index c926359987d89..5bc4cdf65306c 100644 --- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala +++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala @@ -17,20 +17,12 @@ package org.apache.spark.streaming.flume -import java.net.{InetSocketAddress, ServerSocket} -import java.nio.ByteBuffer - import scala.collection.JavaConversions._ import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} import scala.concurrent.duration._ import scala.language.postfixOps import com.google.common.base.Charsets -import org.apache.avro.ipc.NettyTransceiver -import org.apache.avro.ipc.specific.SpecificRequestor -import org.apache.commons.lang3.RandomUtils -import org.apache.flume.source.avro -import org.apache.flume.source.avro.{AvroFlumeEvent, AvroSourceProtocol} import org.jboss.netty.channel.ChannelPipeline import org.jboss.netty.channel.socket.SocketChannel import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory @@ -41,22 +33,10 @@ import org.scalatest.concurrent.Eventually._ import org.apache.spark.{Logging, SparkConf, SparkFunSuite} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{Milliseconds, StreamingContext, TestOutputStream} -import org.apache.spark.util.Utils class FlumeStreamSuite extends SparkFunSuite with BeforeAndAfter with Matchers with Logging { val conf = new SparkConf().setMaster("local[4]").setAppName("FlumeStreamSuite") - var ssc: StreamingContext = null - var transceiver: NettyTransceiver = null - - after { - if (ssc != null) { - ssc.stop() - } - if (transceiver != null) { - transceiver.close() - } - } test("flume input stream") { testFlumeStream(testCompression = false) @@ -69,19 +49,29 @@ class FlumeStreamSuite extends SparkFunSuite with BeforeAndAfter with Matchers w /** Run test on flume stream */ private def testFlumeStream(testCompression: Boolean): Unit = { val input = (1 to 100).map { _.toString } - val testPort = findFreePort() - val outputBuffer = startContext(testPort, testCompression) - writeAndVerify(input, testPort, outputBuffer, testCompression) - } + val utils = new FlumeTestUtils + try { + val outputBuffer = startContext(utils.getTestPort(), testCompression) - /** Find a free port */ - private def findFreePort(): Int = { - val candidatePort = RandomUtils.nextInt(1024, 65536) - Utils.startServiceOnPort(candidatePort, (trialPort: Int) => { - val socket = new ServerSocket(trialPort) - socket.close() - (null, trialPort) - }, conf)._2 + eventually(timeout(10 seconds), interval(100 milliseconds)) { + utils.writeInput(input, testCompression) + } + + eventually(timeout(10 seconds), interval(100 milliseconds)) { + val outputEvents = outputBuffer.flatten.map { _.event } + outputEvents.foreach { + event => + event.getHeaders.get("test") should be("header") + } + val output = outputEvents.map(event => new String(event.getBody.array(), Charsets.UTF_8)) + output should be (input) + } + } finally { + if (ssc != null) { + ssc.stop() + } + utils.close() + } } /** Setup and start the streaming context */ @@ -98,58 +88,6 @@ class FlumeStreamSuite extends SparkFunSuite with BeforeAndAfter with Matchers w outputBuffer } - /** Send data to the flume receiver and verify whether the data was received */ - private def writeAndVerify( - input: Seq[String], - testPort: Int, - outputBuffer: ArrayBuffer[Seq[SparkFlumeEvent]], - enableCompression: Boolean - ) { - val testAddress = new InetSocketAddress("localhost", testPort) - - val inputEvents = input.map { item => - val event = new AvroFlumeEvent - event.setBody(ByteBuffer.wrap(item.getBytes(Charsets.UTF_8))) - event.setHeaders(Map[CharSequence, CharSequence]("test" -> "header")) - event - } - - eventually(timeout(10 seconds), interval(100 milliseconds)) { - // if last attempted transceiver had succeeded, close it - if (transceiver != null) { - transceiver.close() - transceiver = null - } - - // Create transceiver - transceiver = { - if (enableCompression) { - new NettyTransceiver(testAddress, new CompressionChannelFactory(6)) - } else { - new NettyTransceiver(testAddress) - } - } - - // Create Avro client with the transceiver - val client = SpecificRequestor.getClient(classOf[AvroSourceProtocol], transceiver) - client should not be null - - // Send data - val status = client.appendBatch(inputEvents.toList) - status should be (avro.Status.OK) - } - - eventually(timeout(10 seconds), interval(100 milliseconds)) { - val outputEvents = outputBuffer.flatten.map { _.event } - outputEvents.foreach { - event => - event.getHeaders.get("test") should be("header") - } - val output = outputEvents.map(event => new String(event.getBody.array(), Charsets.UTF_8)) - output should be (input) - } - } - /** Class to create socket channel with compression */ private class CompressionChannelFactory(compressionLevel: Int) extends NioClientSocketChannelFactory { diff --git a/pom.xml b/pom.xml index 94dd512cfb618..211da9ee74a3f 100644 --- a/pom.xml +++ b/pom.xml @@ -102,6 +102,7 @@ external/twitter external/flume external/flume-sink + external/flume-assembly external/mqtt external/zeromq examples diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index f5f1c9a1a247a..4ef4dc8bdc039 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -45,8 +45,8 @@ object BuildCommons { sparkKinesisAsl) = Seq("yarn", "yarn-stable", "java8-tests", "ganglia-lgpl", "kinesis-asl").map(ProjectRef(buildLocation, _)) - val assemblyProjects@Seq(assembly, examples, networkYarn, streamingKafkaAssembly) = - Seq("assembly", "examples", "network-yarn", "streaming-kafka-assembly") + val assemblyProjects@Seq(assembly, examples, networkYarn, streamingFlumeAssembly, streamingKafkaAssembly) = + Seq("assembly", "examples", "network-yarn", "streaming-flume-assembly", "streaming-kafka-assembly") .map(ProjectRef(buildLocation, _)) val tools = ProjectRef(buildLocation, "tools") @@ -347,7 +347,7 @@ object Assembly { .getOrElse(SbtPomKeys.effectivePom.value.getProperties.get("hadoop.version").asInstanceOf[String]) }, jarName in assembly <<= (version, moduleName, hadoopVersion) map { (v, mName, hv) => - if (mName.contains("streaming-kafka-assembly")) { + if (mName.contains("streaming-flume-assembly") || mName.contains("streaming-kafka-assembly")) { // This must match the same name used in maven (see external/kafka-assembly/pom.xml) s"${mName}-${v}.jar" } else { diff --git a/python/pyspark/streaming/flume.py b/python/pyspark/streaming/flume.py new file mode 100644 index 0000000000000..cbb573f226bbe --- /dev/null +++ b/python/pyspark/streaming/flume.py @@ -0,0 +1,147 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import sys +if sys.version >= "3": + from io import BytesIO +else: + from StringIO import StringIO +from py4j.java_gateway import Py4JJavaError + +from pyspark.storagelevel import StorageLevel +from pyspark.serializers import PairDeserializer, NoOpSerializer, UTF8Deserializer, read_int +from pyspark.streaming import DStream + +__all__ = ['FlumeUtils', 'utf8_decoder'] + + +def utf8_decoder(s): + """ Decode the unicode as UTF-8 """ + return s and s.decode('utf-8') + + +class FlumeUtils(object): + + @staticmethod + def createStream(ssc, hostname, port, + storageLevel=StorageLevel.MEMORY_AND_DISK_SER_2, + enableDecompression=False, + bodyDecoder=utf8_decoder): + """ + Create an input stream that pulls events from Flume. + + :param ssc: StreamingContext object + :param hostname: Hostname of the slave machine to which the flume data will be sent + :param port: Port of the slave machine to which the flume data will be sent + :param storageLevel: Storage level to use for storing the received objects + :param enableDecompression: Should netty server decompress input stream + :param bodyDecoder: A function used to decode body (default is utf8_decoder) + :return: A DStream object + """ + jlevel = ssc._sc._getJavaStorageLevel(storageLevel) + + try: + helperClass = ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader()\ + .loadClass("org.apache.spark.streaming.flume.FlumeUtilsPythonHelper") + helper = helperClass.newInstance() + jstream = helper.createStream(ssc._jssc, hostname, port, jlevel, enableDecompression) + except Py4JJavaError as e: + if 'ClassNotFoundException' in str(e.java_exception): + FlumeUtils._printErrorMsg(ssc.sparkContext) + raise e + + return FlumeUtils._toPythonDStream(ssc, jstream, bodyDecoder) + + @staticmethod + def createPollingStream(ssc, addresses, + storageLevel=StorageLevel.MEMORY_AND_DISK_SER_2, + maxBatchSize=1000, + parallelism=5, + bodyDecoder=utf8_decoder): + """ + Creates an input stream that is to be used with the Spark Sink deployed on a Flume agent. + This stream will poll the sink for data and will pull events as they are available. + + :param ssc: StreamingContext object + :param addresses: List of (host, port)s on which the Spark Sink is running. + :param storageLevel: Storage level to use for storing the received objects + :param maxBatchSize: The maximum number of events to be pulled from the Spark sink + in a single RPC call + :param parallelism: Number of concurrent requests this stream should send to the sink. + Note that having a higher number of requests concurrently being pulled + will result in this stream using more threads + :param bodyDecoder: A function used to decode body (default is utf8_decoder) + :return: A DStream object + """ + jlevel = ssc._sc._getJavaStorageLevel(storageLevel) + hosts = [] + ports = [] + for (host, port) in addresses: + hosts.append(host) + ports.append(port) + + try: + helperClass = ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \ + .loadClass("org.apache.spark.streaming.flume.FlumeUtilsPythonHelper") + helper = helperClass.newInstance() + jstream = helper.createPollingStream( + ssc._jssc, hosts, ports, jlevel, maxBatchSize, parallelism) + except Py4JJavaError as e: + if 'ClassNotFoundException' in str(e.java_exception): + FlumeUtils._printErrorMsg(ssc.sparkContext) + raise e + + return FlumeUtils._toPythonDStream(ssc, jstream, bodyDecoder) + + @staticmethod + def _toPythonDStream(ssc, jstream, bodyDecoder): + ser = PairDeserializer(NoOpSerializer(), NoOpSerializer()) + stream = DStream(jstream, ssc, ser) + + def func(event): + headersBytes = BytesIO(event[0]) if sys.version >= "3" else StringIO(event[0]) + headers = {} + strSer = UTF8Deserializer() + for i in range(0, read_int(headersBytes)): + key = strSer.loads(headersBytes) + value = strSer.loads(headersBytes) + headers[key] = value + body = bodyDecoder(event[1]) + return (headers, body) + return stream.map(func) + + @staticmethod + def _printErrorMsg(sc): + print(""" +________________________________________________________________________________________________ + + Spark Streaming's Flume libraries not found in class path. Try one of the following. + + 1. Include the Flume library and its dependencies with in the + spark-submit command as + + $ bin/spark-submit --packages org.apache.spark:spark-streaming-flume:%s ... + + 2. Download the JAR of the artifact from Maven Central http://search.maven.org/, + Group Id = org.apache.spark, Artifact Id = spark-streaming-flume-assembly, Version = %s. + Then, include the jar in the spark-submit command as + + $ bin/spark-submit --jars ... + +________________________________________________________________________________________________ + +""" % (sc.version, sc.version)) diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 91ce681fbe169..188c8ff12067e 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -38,6 +38,7 @@ from pyspark.context import SparkConf, SparkContext, RDD from pyspark.streaming.context import StreamingContext from pyspark.streaming.kafka import Broker, KafkaUtils, OffsetRange, TopicAndPartition +from pyspark.streaming.flume import FlumeUtils class PySparkStreamingTestCase(unittest.TestCase): @@ -677,7 +678,156 @@ def test_kafka_rdd_with_leaders(self): rdd = KafkaUtils.createRDD(self.sc, kafkaParams, offsetRanges, leaders) self._validateRddResult(sendData, rdd) -if __name__ == "__main__": + +class FlumeStreamTests(PySparkStreamingTestCase): + timeout = 20 # seconds + duration = 1 + + def setUp(self): + super(FlumeStreamTests, self).setUp() + + utilsClz = self.ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \ + .loadClass("org.apache.spark.streaming.flume.FlumeTestUtils") + self._utils = utilsClz.newInstance() + + def tearDown(self): + if self._utils is not None: + self._utils.close() + self._utils = None + + super(FlumeStreamTests, self).tearDown() + + def _startContext(self, n, compressed): + # Start the StreamingContext and also collect the result + dstream = FlumeUtils.createStream(self.ssc, "localhost", self._utils.getTestPort(), + enableDecompression=compressed) + result = [] + + def get_output(_, rdd): + for event in rdd.collect(): + if len(result) < n: + result.append(event) + dstream.foreachRDD(get_output) + self.ssc.start() + return result + + def _validateResult(self, input, result): + # Validate both the header and the body + header = {"test": "header"} + self.assertEqual(len(input), len(result)) + for i in range(0, len(input)): + self.assertEqual(header, result[i][0]) + self.assertEqual(input[i], result[i][1]) + + def _writeInput(self, input, compressed): + # Try to write input to the receiver until success or timeout + start_time = time.time() + while True: + try: + self._utils.writeInput(input, compressed) + break + except: + if time.time() - start_time < self.timeout: + time.sleep(0.01) + else: + raise + + def test_flume_stream(self): + input = [str(i) for i in range(1, 101)] + result = self._startContext(len(input), False) + self._writeInput(input, False) + self.wait_for(result, len(input)) + self._validateResult(input, result) + + def test_compressed_flume_stream(self): + input = [str(i) for i in range(1, 101)] + result = self._startContext(len(input), True) + self._writeInput(input, True) + self.wait_for(result, len(input)) + self._validateResult(input, result) + + +class FlumePollingStreamTests(PySparkStreamingTestCase): + timeout = 20 # seconds + duration = 1 + maxAttempts = 5 + + def setUp(self): + utilsClz = \ + self.sc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \ + .loadClass("org.apache.spark.streaming.flume.PollingFlumeTestUtils") + self._utils = utilsClz.newInstance() + + def tearDown(self): + if self._utils is not None: + self._utils.close() + self._utils = None + + def _writeAndVerify(self, ports): + # Set up the streaming context and input streams + ssc = StreamingContext(self.sc, self.duration) + try: + addresses = [("localhost", port) for port in ports] + dstream = FlumeUtils.createPollingStream( + ssc, + addresses, + maxBatchSize=self._utils.eventsPerBatch(), + parallelism=5) + outputBuffer = [] + + def get_output(_, rdd): + for e in rdd.collect(): + outputBuffer.append(e) + + dstream.foreachRDD(get_output) + ssc.start() + self._utils.sendDatAndEnsureAllDataHasBeenReceived() + + self.wait_for(outputBuffer, self._utils.getTotalEvents()) + outputHeaders = [event[0] for event in outputBuffer] + outputBodies = [event[1] for event in outputBuffer] + self._utils.assertOutput(outputHeaders, outputBodies) + finally: + ssc.stop(False) + + def _testMultipleTimes(self, f): + attempt = 0 + while True: + try: + f() + break + except: + attempt += 1 + if attempt >= self.maxAttempts: + raise + else: + import traceback + traceback.print_exc() + + def _testFlumePolling(self): + try: + port = self._utils.startSingleSink() + self._writeAndVerify([port]) + self._utils.assertChannelsAreEmpty() + finally: + self._utils.close() + + def _testFlumePollingMultipleHosts(self): + try: + port = self._utils.startSingleSink() + self._writeAndVerify([port]) + self._utils.assertChannelsAreEmpty() + finally: + self._utils.close() + + def test_flume_polling(self): + self._testMultipleTimes(self._testFlumePolling) + + def test_flume_polling_multiple_hosts(self): + self._testMultipleTimes(self._testFlumePollingMultipleHosts) + + +def search_kafka_assembly_jar(): SPARK_HOME = os.environ["SPARK_HOME"] kafka_assembly_dir = os.path.join(SPARK_HOME, "external/kafka-assembly") jars = glob.glob( @@ -692,5 +842,30 @@ def test_kafka_rdd_with_leaders(self): raise Exception(("Found multiple Spark Streaming Kafka assembly JARs in %s; please " "remove all but one") % kafka_assembly_dir) else: - os.environ["PYSPARK_SUBMIT_ARGS"] = "--jars %s pyspark-shell" % jars[0] + return jars[0] + + +def search_flume_assembly_jar(): + SPARK_HOME = os.environ["SPARK_HOME"] + flume_assembly_dir = os.path.join(SPARK_HOME, "external/flume-assembly") + jars = glob.glob( + os.path.join(flume_assembly_dir, "target/scala-*/spark-streaming-flume-assembly-*.jar")) + if not jars: + raise Exception( + ("Failed to find Spark Streaming Flume assembly jar in %s. " % flume_assembly_dir) + + "You need to build Spark with " + "'build/sbt assembly/assembly streaming-flume-assembly/assembly' or " + "'build/mvn package' before running this test") + elif len(jars) > 1: + raise Exception(("Found multiple Spark Streaming Flume assembly JARs in %s; please " + "remove all but one") % flume_assembly_dir) + else: + return jars[0] + +if __name__ == "__main__": + kafka_assembly_jar = search_kafka_assembly_jar() + flume_assembly_jar = search_flume_assembly_jar() + jars = "%s,%s" % (kafka_assembly_jar, flume_assembly_jar) + + os.environ["PYSPARK_SUBMIT_ARGS"] = "--jars %s pyspark-shell" % jars unittest.main() From 9f7db3486fcb403cae8da9dfce8978373c3f47b7 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Wed, 1 Jul 2015 12:33:24 -0700 Subject: [PATCH 12/18] [SPARK-7820] [BUILD] Fix Java8-tests suite compile and test error under sbt Author: jerryshao Closes #7120 from jerryshao/SPARK-7820 and squashes the following commits: 6902439 [jerryshao] fix Java8-tests suite compile error under sbt --- extras/java8-tests/pom.xml | 8 ++++++++ project/SparkBuild.scala | 4 ++-- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/extras/java8-tests/pom.xml b/extras/java8-tests/pom.xml index f138251748c9e..3636a9037d43f 100644 --- a/extras/java8-tests/pom.xml +++ b/extras/java8-tests/pom.xml @@ -39,6 +39,13 @@ spark-core_${scala.binary.version} ${project.version} + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + org.apache.spark spark-streaming_${scala.binary.version} @@ -49,6 +56,7 @@ spark-streaming_${scala.binary.version} ${project.version} test-jar + test junit diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 4ef4dc8bdc039..5f389bcc9ceeb 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -161,7 +161,7 @@ object SparkBuild extends PomBuild { // Note ordering of these settings matter. /* Enable shared settings on all projects */ (allProjects ++ optionallyEnabledProjects ++ assemblyProjects ++ Seq(spark, tools)) - .foreach(enable(sharedSettings ++ ExludedDependencies.settings ++ Revolver.settings)) + .foreach(enable(sharedSettings ++ ExcludedDependencies.settings ++ Revolver.settings)) /* Enable tests settings for all projects except examples, assembly and tools */ (allProjects ++ optionallyEnabledProjects).foreach(enable(TestSettings.settings)) @@ -246,7 +246,7 @@ object Flume { This excludes library dependencies in sbt, which are specified in maven but are not needed by sbt build. */ -object ExludedDependencies { +object ExcludedDependencies { lazy val settings = Seq( libraryDependencies ~= { libs => libs.filterNot(_.name == "groovy-all") } ) From 3083e17645e4b707646fe48e406e02c156a0f37b Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 1 Jul 2015 12:39:57 -0700 Subject: [PATCH 13/18] [QUICKFIX] [SQL] fix copy of generated row copy() of generated Row doesn't check nullability of columns Author: Davies Liu Closes #7163 from davies/fix_copy and squashes the following commits: 661a206 [Davies Liu] fix copy of generated row --- .../sql/catalyst/expressions/codegen/GenerateProjection.scala | 2 +- .../spark/sql/catalyst/expressions/ExpressionEvalHelper.scala | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index 5be47175fa7f1..3c7ee9cc16599 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -148,7 +148,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { }.mkString("\n") val copyColumns = expressions.zipWithIndex.map { case (e, i) => - s"""arr[$i] = c$i;""" + s"""if (!nullBits[$i]) arr[$i] = c$i;""" }.mkString("\n ") val code = s""" diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index 7d95ef7f710af..3171caf6ad77f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -136,6 +136,9 @@ trait ExpressionEvalHelper { val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input") } + if (actual.copy() != expectedRow) { + fail(s"Copy of generated Row is wrong: actual: ${actual.copy()}, expected: $expectedRow") + } } protected def checkEvaluationWithOptimization( From 1ce6428907b4ddcf52dbf0c86196d82ab7392442 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Wed, 1 Jul 2015 20:40:47 +0100 Subject: [PATCH 14/18] [SPARK-3444] [CORE] Restore INFO level after log4j test. Otherwise other tests don't log anything useful... Author: Marcelo Vanzin Closes #7140 from vanzin/SPARK-3444 and squashes the following commits: de14836 [Marcelo Vanzin] Better fix. 6cff13a [Marcelo Vanzin] [SPARK-3444] [core] Restore INFO level after log4j test. --- .../scala/org/apache/spark/util/UtilsSuite.scala | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index baa4c661cc21e..251a797dc28a2 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -486,11 +486,17 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { // Test for using the util function to change our log levels. test("log4j log level change") { - Utils.setLogLevel(org.apache.log4j.Level.ALL) - assert(log.isInfoEnabled()) - Utils.setLogLevel(org.apache.log4j.Level.ERROR) - assert(!log.isInfoEnabled()) - assert(log.isErrorEnabled()) + val current = org.apache.log4j.Logger.getRootLogger().getLevel() + try { + Utils.setLogLevel(org.apache.log4j.Level.ALL) + assert(log.isInfoEnabled()) + Utils.setLogLevel(org.apache.log4j.Level.ERROR) + assert(!log.isInfoEnabled()) + assert(log.isErrorEnabled()) + } finally { + // Best effort at undoing changes this test made. + Utils.setLogLevel(current) + } } test("deleteRecursively") { From f958f27e2056f9e380373c2807d8bb5977ecf269 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 1 Jul 2015 16:43:18 -0700 Subject: [PATCH 15/18] [SPARK-8766] support non-ascii character in column names Use UTF-8 to encode the name of column in Python 2, or it may failed to encode with default encoding ('ascii'). This PR also fix a bug when there is Java exception without error message. Author: Davies Liu Closes #7165 from davies/non_ascii and squashes the following commits: 02cb61a [Davies Liu] fix tests 3b09d31 [Davies Liu] add encoding in header 867754a [Davies Liu] support non-ascii character in column names --- python/pyspark/sql/dataframe.py | 3 +-- python/pyspark/sql/tests.py | 9 +++++++++ python/pyspark/sql/types.py | 2 ++ python/pyspark/sql/utils.py | 6 +++--- 4 files changed, 15 insertions(+), 5 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 4b9efa0a210fb..273a40dd526cf 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -484,13 +484,12 @@ def dtypes(self): return [(str(f.name), f.dataType.simpleString()) for f in self.schema.fields] @property - @ignore_unicode_prefix @since(1.3) def columns(self): """Returns all column names as a list. >>> df.columns - [u'age', u'name'] + ['age', 'name'] """ return [f.name for f in self.schema.fields] diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 5af2ce09bc122..333378c7f1854 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1,3 +1,4 @@ +# -*- encoding: utf-8 -*- # # Licensed to the Apache Software Foundation (ASF) under one or more # contributor license agreements. See the NOTICE file distributed with @@ -628,6 +629,14 @@ def test_access_column(self): self.assertRaises(IndexError, lambda: df["bad_key"]) self.assertRaises(TypeError, lambda: df[{}]) + def test_column_name_with_non_ascii(self): + df = self.sqlCtx.createDataFrame([(1,)], ["数量"]) + self.assertEqual(StructType([StructField("数量", LongType(), True)]), df.schema) + self.assertEqual("DataFrame[数量: bigint]", str(df)) + self.assertEqual([("数量", 'bigint')], df.dtypes) + self.assertEqual(1, df.select("数量").first()[0]) + self.assertEqual(1, df.select(df["数量"]).first()[0]) + def test_access_nested_types(self): df = self.sc.parallelize([Row(l=[1], r=Row(a=1, b="b"), d={"k": "v"})]).toDF() self.assertEqual(1, df.select(df.l[0]).first()[0]) diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index ae9344e6106a4..160df40d65cc1 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -324,6 +324,8 @@ def __init__(self, name, dataType, nullable=True, metadata=None): False """ assert isinstance(dataType, DataType), "dataType should be DataType" + if not isinstance(name, str): + name = name.encode('utf-8') self.name = name self.dataType = dataType self.nullable = nullable diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py index 8096802e7302f..cc5b2c088b7cc 100644 --- a/python/pyspark/sql/utils.py +++ b/python/pyspark/sql/utils.py @@ -29,9 +29,9 @@ def deco(*a, **kw): try: return f(*a, **kw) except py4j.protocol.Py4JJavaError as e: - cls, msg = e.java_exception.toString().split(': ', 1) - if cls == 'org.apache.spark.sql.AnalysisException': - raise AnalysisException(msg) + s = e.java_exception.toString() + if s.startswith('org.apache.spark.sql.AnalysisException: '): + raise AnalysisException(s.split(': ', 1)[1]) raise return deco From 272778999823ed79af92280350c5869a87a21f29 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 1 Jul 2015 16:56:48 -0700 Subject: [PATCH 16/18] [SPARK-8770][SQL] Create BinaryOperator abstract class. Our current BinaryExpression abstract class is not for generic binary expressions, i.e. it requires left/right children to have the same type. However, due to its name, contributors build new binary expressions that don't have that assumption (e.g. Sha) and still extend BinaryExpression. This patch creates a new BinaryOperator abstract class, and update the analyzer o only apply type casting rule there. This patch also adds the notion of "prettyName" to expressions, which defines the user-facing name for the expression. Author: Reynold Xin Closes #7170 from rxin/binaryoperator and squashes the following commits: 51264a5 [Reynold Xin] [SPARK-8770][SQL] Create BinaryOperator abstract class. --- .../catalyst/analysis/HiveTypeCoercion.scala | 17 +- .../expressions/ExpectsInputTypes.scala | 59 +++++++ .../sql/catalyst/expressions/Expression.scala | 161 +++++++++--------- .../sql/catalyst/expressions/ScalaUDF.scala | 2 +- .../sql/catalyst/expressions/aggregates.scala | 6 - .../sql/catalyst/expressions/arithmetic.scala | 14 +- .../expressions/complexTypeCreator.scala | 4 +- .../catalyst/expressions/nullFunctions.scala | 2 - .../sql/catalyst/expressions/predicates.scala | 6 +- .../spark/sql/catalyst/expressions/sets.scala | 2 - .../expressions/stringOperations.scala | 26 +-- .../sql/catalyst/trees/TreeNodeSuite.scala | 6 +- 12 files changed, 170 insertions(+), 135 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 2ab5cb666fbcd..8420c54f7c335 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -150,6 +150,7 @@ object HiveTypeCoercion { * Converts string "NaN"s that are in binary operators with a NaN-able types (Float / Double) to * the appropriate numeric equivalent. */ + // TODO: remove this rule and make Cast handle Nan. object ConvertNaNs extends Rule[LogicalPlan] { private val StringNaN = Literal("NaN") @@ -159,19 +160,19 @@ object HiveTypeCoercion { case e if !e.childrenResolved => e /* Double Conversions */ - case b @ BinaryExpression(StringNaN, right @ DoubleType()) => + case b @ BinaryOperator(StringNaN, right @ DoubleType()) => b.makeCopy(Array(Literal(Double.NaN), right)) - case b @ BinaryExpression(left @ DoubleType(), StringNaN) => + case b @ BinaryOperator(left @ DoubleType(), StringNaN) => b.makeCopy(Array(left, Literal(Double.NaN))) /* Float Conversions */ - case b @ BinaryExpression(StringNaN, right @ FloatType()) => + case b @ BinaryOperator(StringNaN, right @ FloatType()) => b.makeCopy(Array(Literal(Float.NaN), right)) - case b @ BinaryExpression(left @ FloatType(), StringNaN) => + case b @ BinaryOperator(left @ FloatType(), StringNaN) => b.makeCopy(Array(left, Literal(Float.NaN))) /* Use float NaN by default to avoid unnecessary type widening */ - case b @ BinaryExpression(left @ StringNaN, StringNaN) => + case b @ BinaryOperator(left @ StringNaN, StringNaN) => b.makeCopy(Array(left, Literal(Float.NaN))) } } @@ -245,12 +246,12 @@ object HiveTypeCoercion { Union(newLeft, newRight) - // Also widen types for BinaryExpressions. + // Also widen types for BinaryOperator. case q: LogicalPlan => q transformExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e - case b @ BinaryExpression(left, right) if left.dataType != right.dataType => + case b @ BinaryOperator(left, right) if left.dataType != right.dataType => findTightestCommonTypeOfTwo(left.dataType, right.dataType).map { widestType => val newLeft = if (left.dataType == widestType) left else Cast(left, widestType) val newRight = if (right.dataType == widestType) right else Cast(right, widestType) @@ -478,7 +479,7 @@ object HiveTypeCoercion { // Promote integers inside a binary expression with fixed-precision decimals to decimals, // and fixed-precision decimals in an expression with floats / doubles to doubles - case b @ BinaryExpression(left, right) if left.dataType != right.dataType => + case b @ BinaryOperator(left, right) if left.dataType != right.dataType => (left.dataType, right.dataType) match { case (t, DecimalType.Fixed(p, s)) if intTypeToFixed.contains(t) => b.makeCopy(Array(Cast(left, intTypeToFixed(t)), right)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala new file mode 100644 index 0000000000000..450fc4165f93b --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.types.DataType + + +/** + * An trait that gets mixin to define the expected input types of an expression. + */ +trait ExpectsInputTypes { self: Expression => + + /** + * Expected input types from child expressions. The i-th position in the returned seq indicates + * the type requirement for the i-th child. + * + * The possible values at each position are: + * 1. a specific data type, e.g. LongType, StringType. + * 2. a non-leaf data type, e.g. NumericType, IntegralType, FractionalType. + * 3. a list of specific data types, e.g. Seq(StringType, BinaryType). + */ + def inputTypes: Seq[Any] + + override def checkInputDataTypes(): TypeCheckResult = { + // We will do the type checking in `HiveTypeCoercion`, so always returning success here. + TypeCheckResult.TypeCheckSuccess + } +} + +/** + * Expressions that require a specific `DataType` as input should implement this trait + * so that the proper type conversions can be performed in the analyzer. + */ +trait AutoCastInputTypes { self: Expression => + + def inputTypes: Seq[DataType] + + override def checkInputDataTypes(): TypeCheckResult = { + // We will always do type casting for `AutoCastInputTypes` in `HiveTypeCoercion`, + // so type mismatch error won't be reported here, but for underling `Cast`s. + TypeCheckResult.TypeCheckSuccess + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index e18a3118945e8..cafbbafdca207 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -119,17 +119,6 @@ abstract class Expression extends TreeNode[Expression] { */ def childrenResolved: Boolean = children.forall(_.resolved) - /** - * Returns a string representation of this expression that does not have developer centric - * debugging information like the expression id. - */ - def prettyString: String = { - transform { - case a: AttributeReference => PrettyAttribute(a.name) - case u: UnresolvedAttribute => PrettyAttribute(u.name) - }.toString - } - /** * Returns true when two expressions will always compute the same result, even if they differ * cosmetically (i.e. capitalization of names in attributes may be different). @@ -154,71 +143,40 @@ abstract class Expression extends TreeNode[Expression] { * Note: it's not valid to call this method until `childrenResolved == true`. */ def checkInputDataTypes(): TypeCheckResult = TypeCheckResult.TypeCheckSuccess -} - -abstract class BinaryExpression extends Expression with trees.BinaryNode[Expression] { - self: Product => - - def symbol: String = sys.error(s"BinaryExpressions must override either toString or symbol") - - override def foldable: Boolean = left.foldable && right.foldable - - override def nullable: Boolean = left.nullable || right.nullable - - override def toString: String = s"($left $symbol $right)" /** - * Short hand for generating binary evaluation code. - * If either of the sub-expressions is null, the result of this computation - * is assumed to be null. - * - * @param f accepts two variable names and returns Java code to compute the output. + * Returns a user-facing string representation of this expression's name. + * This should usually match the name of the function in SQL. */ - protected def defineCodeGen( - ctx: CodeGenContext, - ev: GeneratedExpressionCode, - f: (String, String) => String): String = { - nullSafeCodeGen(ctx, ev, (result, eval1, eval2) => { - s"$result = ${f(eval1, eval2)};" - }) - } + def prettyName: String = getClass.getSimpleName.toLowerCase /** - * Short hand for generating binary evaluation code. - * If either of the sub-expressions is null, the result of this computation - * is assumed to be null. + * Returns a user-facing string representation of this expression, i.e. does not have developer + * centric debugging information like the expression id. */ - protected def nullSafeCodeGen( - ctx: CodeGenContext, - ev: GeneratedExpressionCode, - f: (String, String, String) => String): String = { - val eval1 = left.gen(ctx) - val eval2 = right.gen(ctx) - val resultCode = f(ev.primitive, eval1.primitive, eval2.primitive) - s""" - ${eval1.code} - boolean ${ev.isNull} = ${eval1.isNull}; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; - if (!${ev.isNull}) { - ${eval2.code} - if (!${eval2.isNull}) { - $resultCode - } else { - ${ev.isNull} = true; - } - } - """ + def prettyString: String = { + transform { + case a: AttributeReference => PrettyAttribute(a.name) + case u: UnresolvedAttribute => PrettyAttribute(u.name) + }.toString } -} -private[sql] object BinaryExpression { - def unapply(e: BinaryExpression): Option[(Expression, Expression)] = Some((e.left, e.right)) + override def toString: String = prettyName + children.mkString("(", ",", ")") } + +/** + * A leaf expression, i.e. one without any child expressions. + */ abstract class LeafExpression extends Expression with trees.LeafNode[Expression] { self: Product => } + +/** + * An expression with one input and one output. The output is by default evaluated to null + * if the input is evaluated to null. + */ abstract class UnaryExpression extends Expression with trees.UnaryNode[Expression] { self: Product => @@ -265,39 +223,76 @@ abstract class UnaryExpression extends Expression with trees.UnaryNode[Expressio } } + /** - * An trait that gets mixin to define the expected input types of an expression. + * An expression with two inputs and one output. The output is by default evaluated to null + * if any input is evaluated to null. */ -trait ExpectsInputTypes { self: Expression => +abstract class BinaryExpression extends Expression with trees.BinaryNode[Expression] { + self: Product => + + override def foldable: Boolean = left.foldable && right.foldable + + override def nullable: Boolean = left.nullable || right.nullable /** - * Expected input types from child expressions. The i-th position in the returned seq indicates - * the type requirement for the i-th child. + * Short hand for generating binary evaluation code. + * If either of the sub-expressions is null, the result of this computation + * is assumed to be null. * - * The possible values at each position are: - * 1. a specific data type, e.g. LongType, StringType. - * 2. a non-leaf data type, e.g. NumericType, IntegralType, FractionalType. - * 3. a list of specific data types, e.g. Seq(StringType, BinaryType). + * @param f accepts two variable names and returns Java code to compute the output. */ - def inputTypes: Seq[Any] + protected def defineCodeGen( + ctx: CodeGenContext, + ev: GeneratedExpressionCode, + f: (String, String) => String): String = { + nullSafeCodeGen(ctx, ev, (result, eval1, eval2) => { + s"$result = ${f(eval1, eval2)};" + }) + } - override def checkInputDataTypes(): TypeCheckResult = { - // We will do the type checking in `HiveTypeCoercion`, so always returning success here. - TypeCheckResult.TypeCheckSuccess + /** + * Short hand for generating binary evaluation code. + * If either of the sub-expressions is null, the result of this computation + * is assumed to be null. + */ + protected def nullSafeCodeGen( + ctx: CodeGenContext, + ev: GeneratedExpressionCode, + f: (String, String, String) => String): String = { + val eval1 = left.gen(ctx) + val eval2 = right.gen(ctx) + val resultCode = f(ev.primitive, eval1.primitive, eval2.primitive) + s""" + ${eval1.code} + boolean ${ev.isNull} = ${eval1.isNull}; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${eval2.code} + if (!${eval2.isNull}) { + $resultCode + } else { + ${ev.isNull} = true; + } + } + """ } } + /** - * Expressions that require a specific `DataType` as input should implement this trait - * so that the proper type conversions can be performed in the analyzer. + * An expression that has two inputs that are expected to the be same type. If the two inputs have + * different types, the analyzer will find the tightest common type and do the proper type casting. */ -trait AutoCastInputTypes { self: Expression => +abstract class BinaryOperator extends BinaryExpression { + self: Product => - def inputTypes: Seq[DataType] + def symbol: String - override def checkInputDataTypes(): TypeCheckResult = { - // We will always do type casting for `AutoCastInputTypes` in `HiveTypeCoercion`, - // so type mismatch error won't be reported here, but for underling `Cast`s. - TypeCheckResult.TypeCheckSuccess - } + override def toString: String = s"($left $symbol $right)" +} + + +private[sql] object BinaryOperator { + def unapply(e: BinaryOperator): Option[(Expression, Expression)] = Some((e.left, e.right)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index ebabb6f117851..caf021b016a41 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -29,7 +29,7 @@ case class ScalaUDF(function: AnyRef, dataType: DataType, children: Seq[Expressi override def nullable: Boolean = true - override def toString: String = s"scalaUDF(${children.mkString(",")})" + override def toString: String = s"UDF(${children.mkString(",")})" // scalastyle:off diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index a9fc54c548f49..da520f56b430e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -128,7 +128,6 @@ case class Max(child: Expression) extends PartialAggregate with trees.UnaryNode[ override def nullable: Boolean = true override def dataType: DataType = child.dataType - override def toString: String = s"MAX($child)" override def asPartial: SplitEvaluation = { val partialMax = Alias(Max(child), "PartialMax")() @@ -162,7 +161,6 @@ case class Count(child: Expression) extends PartialAggregate with trees.UnaryNod override def nullable: Boolean = false override def dataType: LongType.type = LongType - override def toString: String = s"COUNT($child)" override def asPartial: SplitEvaluation = { val partialCount = Alias(Count(child), "PartialCount")() @@ -401,8 +399,6 @@ case class Average(child: Expression) extends PartialAggregate with trees.UnaryN DoubleType } - override def toString: String = s"AVG($child)" - override def asPartial: SplitEvaluation = { child.dataType match { case DecimalType.Fixed(_, _) | DecimalType.Unlimited => @@ -494,8 +490,6 @@ case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[ child.dataType } - override def toString: String = s"SUM($child)" - override def asPartial: SplitEvaluation = { child.dataType match { case DecimalType.Fixed(_, _) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 5363b3556886a..4fbf4c87009c2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -57,7 +57,7 @@ case class UnaryMinus(child: Expression) extends UnaryArithmetic { } case class UnaryPositive(child: Expression) extends UnaryArithmetic { - override def toString: String = s"positive($child)" + override def prettyName: String = "positive" override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = defineCodeGen(ctx, ev, c => c) @@ -69,8 +69,6 @@ case class UnaryPositive(child: Expression) extends UnaryArithmetic { * A function that get the absolute value of the numeric value. */ case class Abs(child: Expression) extends UnaryArithmetic { - override def toString: String = s"Abs($child)" - override def checkInputDataTypes(): TypeCheckResult = TypeUtils.checkForNumericExpr(child.dataType, "function abs") @@ -79,10 +77,9 @@ case class Abs(child: Expression) extends UnaryArithmetic { protected override def evalInternal(evalE: Any) = numeric.abs(evalE) } -abstract class BinaryArithmetic extends BinaryExpression { +abstract class BinaryArithmetic extends BinaryOperator { self: Product => - override def dataType: DataType = left.dataType override def checkInputDataTypes(): TypeCheckResult = { @@ -360,7 +357,9 @@ case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic { } """ } - override def toString: String = s"MaxOf($left, $right)" + + override def symbol: String = "max" + override def prettyName: String = symbol } case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic { @@ -413,5 +412,6 @@ case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic { """ } - override def toString: String = s"MinOf($left, $right)" + override def symbol: String = "min" + override def prettyName: String = symbol } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 5def57b067424..67e7dc4ec8b14 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -43,7 +43,7 @@ case class CreateArray(children: Seq[Expression]) extends Expression { children.map(_.eval(input)) } - override def toString: String = s"Array(${children.mkString(",")})" + override def prettyName: String = "array" } /** @@ -71,4 +71,6 @@ case class CreateStruct(children: Seq[Expression]) extends Expression { override def eval(input: InternalRow): Any = { InternalRow(children.map(_.eval(input)): _*) } + + override def prettyName: String = "struct" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala index 78be2824347d7..145d323a9f0bb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala @@ -38,8 +38,6 @@ case class Coalesce(children: Seq[Expression]) extends Expression { } } - override def toString: String = s"Coalesce(${children.mkString(",")})" - override def dataType: DataType = children.head.dataType override def eval(input: InternalRow): Any = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index a777f77add2db..34df89a163895 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -120,7 +120,7 @@ case class InSet(value: Expression, hset: Set[Any]) } case class And(left: Expression, right: Expression) - extends BinaryExpression with Predicate with AutoCastInputTypes { + extends BinaryOperator with Predicate with AutoCastInputTypes { override def inputTypes: Seq[DataType] = Seq(BooleanType, BooleanType) @@ -169,7 +169,7 @@ case class And(left: Expression, right: Expression) } case class Or(left: Expression, right: Expression) - extends BinaryExpression with Predicate with AutoCastInputTypes { + extends BinaryOperator with Predicate with AutoCastInputTypes { override def inputTypes: Seq[DataType] = Seq(BooleanType, BooleanType) @@ -217,7 +217,7 @@ case class Or(left: Expression, right: Expression) } } -abstract class BinaryComparison extends BinaryExpression with Predicate { +abstract class BinaryComparison extends BinaryOperator with Predicate { self: Product => override def checkInputDataTypes(): TypeCheckResult = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala index daa9f4403ffab..5d51a4ca65332 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala @@ -137,8 +137,6 @@ case class CombineSets(left: Expression, right: Expression) extends BinaryExpres override def dataType: DataType = left.dataType - override def symbol: String = "++=" - override def eval(input: InternalRow): Any = { val leftEval = left.eval(input).asInstanceOf[OpenHashSet[Any]] if(leftEval != null) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 4cbfc4e084948..b020f2bbc5818 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -75,8 +75,6 @@ trait StringRegexExpression extends AutoCastInputTypes { case class Like(left: Expression, right: Expression) extends BinaryExpression with StringRegexExpression { - override def symbol: String = "LIKE" - // replace the _ with .{1} exactly match 1 time of any character // replace the % with .*, match 0 or more times with any character override def escape(v: String): String = @@ -101,14 +99,16 @@ case class Like(left: Expression, right: Expression) } override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).matches() + + override def toString: String = s"$left LIKE $right" } case class RLike(left: Expression, right: Expression) extends BinaryExpression with StringRegexExpression { - override def symbol: String = "RLIKE" override def escape(v: String): String = v override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).find(0) + override def toString: String = s"$left RLIKE $right" } trait CaseConversionExpression extends AutoCastInputTypes { @@ -134,9 +134,7 @@ trait CaseConversionExpression extends AutoCastInputTypes { */ case class Upper(child: Expression) extends UnaryExpression with CaseConversionExpression { - override def convert(v: UTF8String): UTF8String = v.toUpperCase() - - override def toString: String = s"Upper($child)" + override def convert(v: UTF8String): UTF8String = v.toUpperCase override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, c => s"($c).toUpperCase()") @@ -148,9 +146,7 @@ case class Upper(child: Expression) extends UnaryExpression with CaseConversionE */ case class Lower(child: Expression) extends UnaryExpression with CaseConversionExpression { - override def convert(v: UTF8String): UTF8String = v.toLowerCase() - - override def toString: String = s"Lower($child)" + override def convert(v: UTF8String): UTF8String = v.toLowerCase override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, c => s"($c).toLowerCase()") @@ -178,8 +174,6 @@ trait StringComparison extends AutoCastInputTypes { } } - override def symbol: String = nodeName - override def toString: String = s"$nodeName($left, $right)" } @@ -284,12 +278,6 @@ case class Substring(str: Expression, pos: Expression, len: Expression) } } } - - override def toString: String = len match { - // TODO: This is broken because max is not an integer value. - case max if max == Integer.MAX_VALUE => s"SUBSTR($str, $pos)" - case _ => s"SUBSTR($str, $pos, $len)" - } } /** @@ -304,9 +292,9 @@ case class StringLength(child: Expression) extends UnaryExpression with AutoCast if (string == null) null else string.asInstanceOf[UTF8String].length } - override def toString: String = s"length($child)" - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, c => s"($c).length()") } + + override def prettyName: String = "length" } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index bda217935cb05..86792f0217572 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -73,7 +73,7 @@ class TreeNodeSuite extends SparkFunSuite { val expected = Seq("+", "1", "*", "2", "-", "3", "4") val expression = Add(Literal(1), Multiply(Literal(2), Subtract(Literal(3), Literal(4)))) expression transformDown { - case b: BinaryExpression => actual.append(b.symbol); b + case b: BinaryOperator => actual.append(b.symbol); b case l: Literal => actual.append(l.toString); l } @@ -85,7 +85,7 @@ class TreeNodeSuite extends SparkFunSuite { val expected = Seq("1", "2", "3", "4", "-", "*", "+") val expression = Add(Literal(1), Multiply(Literal(2), Subtract(Literal(3), Literal(4)))) expression transformUp { - case b: BinaryExpression => actual.append(b.symbol); b + case b: BinaryOperator => actual.append(b.symbol); b case l: Literal => actual.append(l.toString); l } @@ -125,7 +125,7 @@ class TreeNodeSuite extends SparkFunSuite { val expected = Seq("1", "2", "3", "4", "-", "*", "+") val expression = Add(Literal(1), Multiply(Literal(2), Subtract(Literal(3), Literal(4)))) expression foreachUp { - case b: BinaryExpression => actual.append(b.symbol); + case b: BinaryOperator => actual.append(b.symbol); case l: Literal => actual.append(l.toString); } From 3a342dedc04799948bf6da69843bd1a91202ffe5 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 1 Jul 2015 16:59:39 -0700 Subject: [PATCH 17/18] Revert "[SPARK-8770][SQL] Create BinaryOperator abstract class." This reverts commit 272778999823ed79af92280350c5869a87a21f29. --- .../catalyst/analysis/HiveTypeCoercion.scala | 17 +- .../expressions/ExpectsInputTypes.scala | 59 ------- .../sql/catalyst/expressions/Expression.scala | 161 +++++++++--------- .../sql/catalyst/expressions/ScalaUDF.scala | 2 +- .../sql/catalyst/expressions/aggregates.scala | 6 + .../sql/catalyst/expressions/arithmetic.scala | 14 +- .../expressions/complexTypeCreator.scala | 4 +- .../catalyst/expressions/nullFunctions.scala | 2 + .../sql/catalyst/expressions/predicates.scala | 6 +- .../spark/sql/catalyst/expressions/sets.scala | 2 + .../expressions/stringOperations.scala | 26 ++- .../sql/catalyst/trees/TreeNodeSuite.scala | 6 +- 12 files changed, 135 insertions(+), 170 deletions(-) delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 8420c54f7c335..2ab5cb666fbcd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -150,7 +150,6 @@ object HiveTypeCoercion { * Converts string "NaN"s that are in binary operators with a NaN-able types (Float / Double) to * the appropriate numeric equivalent. */ - // TODO: remove this rule and make Cast handle Nan. object ConvertNaNs extends Rule[LogicalPlan] { private val StringNaN = Literal("NaN") @@ -160,19 +159,19 @@ object HiveTypeCoercion { case e if !e.childrenResolved => e /* Double Conversions */ - case b @ BinaryOperator(StringNaN, right @ DoubleType()) => + case b @ BinaryExpression(StringNaN, right @ DoubleType()) => b.makeCopy(Array(Literal(Double.NaN), right)) - case b @ BinaryOperator(left @ DoubleType(), StringNaN) => + case b @ BinaryExpression(left @ DoubleType(), StringNaN) => b.makeCopy(Array(left, Literal(Double.NaN))) /* Float Conversions */ - case b @ BinaryOperator(StringNaN, right @ FloatType()) => + case b @ BinaryExpression(StringNaN, right @ FloatType()) => b.makeCopy(Array(Literal(Float.NaN), right)) - case b @ BinaryOperator(left @ FloatType(), StringNaN) => + case b @ BinaryExpression(left @ FloatType(), StringNaN) => b.makeCopy(Array(left, Literal(Float.NaN))) /* Use float NaN by default to avoid unnecessary type widening */ - case b @ BinaryOperator(left @ StringNaN, StringNaN) => + case b @ BinaryExpression(left @ StringNaN, StringNaN) => b.makeCopy(Array(left, Literal(Float.NaN))) } } @@ -246,12 +245,12 @@ object HiveTypeCoercion { Union(newLeft, newRight) - // Also widen types for BinaryOperator. + // Also widen types for BinaryExpressions. case q: LogicalPlan => q transformExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e - case b @ BinaryOperator(left, right) if left.dataType != right.dataType => + case b @ BinaryExpression(left, right) if left.dataType != right.dataType => findTightestCommonTypeOfTwo(left.dataType, right.dataType).map { widestType => val newLeft = if (left.dataType == widestType) left else Cast(left, widestType) val newRight = if (right.dataType == widestType) right else Cast(right, widestType) @@ -479,7 +478,7 @@ object HiveTypeCoercion { // Promote integers inside a binary expression with fixed-precision decimals to decimals, // and fixed-precision decimals in an expression with floats / doubles to doubles - case b @ BinaryOperator(left, right) if left.dataType != right.dataType => + case b @ BinaryExpression(left, right) if left.dataType != right.dataType => (left.dataType, right.dataType) match { case (t, DecimalType.Fixed(p, s)) if intTypeToFixed.contains(t) => b.makeCopy(Array(Cast(left, intTypeToFixed(t)), right)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala deleted file mode 100644 index 450fc4165f93b..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala +++ /dev/null @@ -1,59 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions - -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.types.DataType - - -/** - * An trait that gets mixin to define the expected input types of an expression. - */ -trait ExpectsInputTypes { self: Expression => - - /** - * Expected input types from child expressions. The i-th position in the returned seq indicates - * the type requirement for the i-th child. - * - * The possible values at each position are: - * 1. a specific data type, e.g. LongType, StringType. - * 2. a non-leaf data type, e.g. NumericType, IntegralType, FractionalType. - * 3. a list of specific data types, e.g. Seq(StringType, BinaryType). - */ - def inputTypes: Seq[Any] - - override def checkInputDataTypes(): TypeCheckResult = { - // We will do the type checking in `HiveTypeCoercion`, so always returning success here. - TypeCheckResult.TypeCheckSuccess - } -} - -/** - * Expressions that require a specific `DataType` as input should implement this trait - * so that the proper type conversions can be performed in the analyzer. - */ -trait AutoCastInputTypes { self: Expression => - - def inputTypes: Seq[DataType] - - override def checkInputDataTypes(): TypeCheckResult = { - // We will always do type casting for `AutoCastInputTypes` in `HiveTypeCoercion`, - // so type mismatch error won't be reported here, but for underling `Cast`s. - TypeCheckResult.TypeCheckSuccess - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index cafbbafdca207..e18a3118945e8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -119,6 +119,17 @@ abstract class Expression extends TreeNode[Expression] { */ def childrenResolved: Boolean = children.forall(_.resolved) + /** + * Returns a string representation of this expression that does not have developer centric + * debugging information like the expression id. + */ + def prettyString: String = { + transform { + case a: AttributeReference => PrettyAttribute(a.name) + case u: UnresolvedAttribute => PrettyAttribute(u.name) + }.toString + } + /** * Returns true when two expressions will always compute the same result, even if they differ * cosmetically (i.e. capitalization of names in attributes may be different). @@ -143,40 +154,71 @@ abstract class Expression extends TreeNode[Expression] { * Note: it's not valid to call this method until `childrenResolved == true`. */ def checkInputDataTypes(): TypeCheckResult = TypeCheckResult.TypeCheckSuccess +} + +abstract class BinaryExpression extends Expression with trees.BinaryNode[Expression] { + self: Product => + + def symbol: String = sys.error(s"BinaryExpressions must override either toString or symbol") + + override def foldable: Boolean = left.foldable && right.foldable + + override def nullable: Boolean = left.nullable || right.nullable + + override def toString: String = s"($left $symbol $right)" /** - * Returns a user-facing string representation of this expression's name. - * This should usually match the name of the function in SQL. + * Short hand for generating binary evaluation code. + * If either of the sub-expressions is null, the result of this computation + * is assumed to be null. + * + * @param f accepts two variable names and returns Java code to compute the output. */ - def prettyName: String = getClass.getSimpleName.toLowerCase + protected def defineCodeGen( + ctx: CodeGenContext, + ev: GeneratedExpressionCode, + f: (String, String) => String): String = { + nullSafeCodeGen(ctx, ev, (result, eval1, eval2) => { + s"$result = ${f(eval1, eval2)};" + }) + } /** - * Returns a user-facing string representation of this expression, i.e. does not have developer - * centric debugging information like the expression id. + * Short hand for generating binary evaluation code. + * If either of the sub-expressions is null, the result of this computation + * is assumed to be null. */ - def prettyString: String = { - transform { - case a: AttributeReference => PrettyAttribute(a.name) - case u: UnresolvedAttribute => PrettyAttribute(u.name) - }.toString + protected def nullSafeCodeGen( + ctx: CodeGenContext, + ev: GeneratedExpressionCode, + f: (String, String, String) => String): String = { + val eval1 = left.gen(ctx) + val eval2 = right.gen(ctx) + val resultCode = f(ev.primitive, eval1.primitive, eval2.primitive) + s""" + ${eval1.code} + boolean ${ev.isNull} = ${eval1.isNull}; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${eval2.code} + if (!${eval2.isNull}) { + $resultCode + } else { + ${ev.isNull} = true; + } + } + """ } - - override def toString: String = prettyName + children.mkString("(", ",", ")") } +private[sql] object BinaryExpression { + def unapply(e: BinaryExpression): Option[(Expression, Expression)] = Some((e.left, e.right)) +} -/** - * A leaf expression, i.e. one without any child expressions. - */ abstract class LeafExpression extends Expression with trees.LeafNode[Expression] { self: Product => } - -/** - * An expression with one input and one output. The output is by default evaluated to null - * if the input is evaluated to null. - */ abstract class UnaryExpression extends Expression with trees.UnaryNode[Expression] { self: Product => @@ -223,76 +265,39 @@ abstract class UnaryExpression extends Expression with trees.UnaryNode[Expressio } } - /** - * An expression with two inputs and one output. The output is by default evaluated to null - * if any input is evaluated to null. + * An trait that gets mixin to define the expected input types of an expression. */ -abstract class BinaryExpression extends Expression with trees.BinaryNode[Expression] { - self: Product => - - override def foldable: Boolean = left.foldable && right.foldable - - override def nullable: Boolean = left.nullable || right.nullable +trait ExpectsInputTypes { self: Expression => /** - * Short hand for generating binary evaluation code. - * If either of the sub-expressions is null, the result of this computation - * is assumed to be null. + * Expected input types from child expressions. The i-th position in the returned seq indicates + * the type requirement for the i-th child. * - * @param f accepts two variable names and returns Java code to compute the output. + * The possible values at each position are: + * 1. a specific data type, e.g. LongType, StringType. + * 2. a non-leaf data type, e.g. NumericType, IntegralType, FractionalType. + * 3. a list of specific data types, e.g. Seq(StringType, BinaryType). */ - protected def defineCodeGen( - ctx: CodeGenContext, - ev: GeneratedExpressionCode, - f: (String, String) => String): String = { - nullSafeCodeGen(ctx, ev, (result, eval1, eval2) => { - s"$result = ${f(eval1, eval2)};" - }) - } + def inputTypes: Seq[Any] - /** - * Short hand for generating binary evaluation code. - * If either of the sub-expressions is null, the result of this computation - * is assumed to be null. - */ - protected def nullSafeCodeGen( - ctx: CodeGenContext, - ev: GeneratedExpressionCode, - f: (String, String, String) => String): String = { - val eval1 = left.gen(ctx) - val eval2 = right.gen(ctx) - val resultCode = f(ev.primitive, eval1.primitive, eval2.primitive) - s""" - ${eval1.code} - boolean ${ev.isNull} = ${eval1.isNull}; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; - if (!${ev.isNull}) { - ${eval2.code} - if (!${eval2.isNull}) { - $resultCode - } else { - ${ev.isNull} = true; - } - } - """ + override def checkInputDataTypes(): TypeCheckResult = { + // We will do the type checking in `HiveTypeCoercion`, so always returning success here. + TypeCheckResult.TypeCheckSuccess } } - /** - * An expression that has two inputs that are expected to the be same type. If the two inputs have - * different types, the analyzer will find the tightest common type and do the proper type casting. + * Expressions that require a specific `DataType` as input should implement this trait + * so that the proper type conversions can be performed in the analyzer. */ -abstract class BinaryOperator extends BinaryExpression { - self: Product => +trait AutoCastInputTypes { self: Expression => - def symbol: String + def inputTypes: Seq[DataType] - override def toString: String = s"($left $symbol $right)" -} - - -private[sql] object BinaryOperator { - def unapply(e: BinaryOperator): Option[(Expression, Expression)] = Some((e.left, e.right)) + override def checkInputDataTypes(): TypeCheckResult = { + // We will always do type casting for `AutoCastInputTypes` in `HiveTypeCoercion`, + // so type mismatch error won't be reported here, but for underling `Cast`s. + TypeCheckResult.TypeCheckSuccess + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index caf021b016a41..ebabb6f117851 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -29,7 +29,7 @@ case class ScalaUDF(function: AnyRef, dataType: DataType, children: Seq[Expressi override def nullable: Boolean = true - override def toString: String = s"UDF(${children.mkString(",")})" + override def toString: String = s"scalaUDF(${children.mkString(",")})" // scalastyle:off diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index da520f56b430e..a9fc54c548f49 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -128,6 +128,7 @@ case class Max(child: Expression) extends PartialAggregate with trees.UnaryNode[ override def nullable: Boolean = true override def dataType: DataType = child.dataType + override def toString: String = s"MAX($child)" override def asPartial: SplitEvaluation = { val partialMax = Alias(Max(child), "PartialMax")() @@ -161,6 +162,7 @@ case class Count(child: Expression) extends PartialAggregate with trees.UnaryNod override def nullable: Boolean = false override def dataType: LongType.type = LongType + override def toString: String = s"COUNT($child)" override def asPartial: SplitEvaluation = { val partialCount = Alias(Count(child), "PartialCount")() @@ -399,6 +401,8 @@ case class Average(child: Expression) extends PartialAggregate with trees.UnaryN DoubleType } + override def toString: String = s"AVG($child)" + override def asPartial: SplitEvaluation = { child.dataType match { case DecimalType.Fixed(_, _) | DecimalType.Unlimited => @@ -490,6 +494,8 @@ case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[ child.dataType } + override def toString: String = s"SUM($child)" + override def asPartial: SplitEvaluation = { child.dataType match { case DecimalType.Fixed(_, _) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 4fbf4c87009c2..5363b3556886a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -57,7 +57,7 @@ case class UnaryMinus(child: Expression) extends UnaryArithmetic { } case class UnaryPositive(child: Expression) extends UnaryArithmetic { - override def prettyName: String = "positive" + override def toString: String = s"positive($child)" override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = defineCodeGen(ctx, ev, c => c) @@ -69,6 +69,8 @@ case class UnaryPositive(child: Expression) extends UnaryArithmetic { * A function that get the absolute value of the numeric value. */ case class Abs(child: Expression) extends UnaryArithmetic { + override def toString: String = s"Abs($child)" + override def checkInputDataTypes(): TypeCheckResult = TypeUtils.checkForNumericExpr(child.dataType, "function abs") @@ -77,9 +79,10 @@ case class Abs(child: Expression) extends UnaryArithmetic { protected override def evalInternal(evalE: Any) = numeric.abs(evalE) } -abstract class BinaryArithmetic extends BinaryOperator { +abstract class BinaryArithmetic extends BinaryExpression { self: Product => + override def dataType: DataType = left.dataType override def checkInputDataTypes(): TypeCheckResult = { @@ -357,9 +360,7 @@ case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic { } """ } - - override def symbol: String = "max" - override def prettyName: String = symbol + override def toString: String = s"MaxOf($left, $right)" } case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic { @@ -412,6 +413,5 @@ case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic { """ } - override def symbol: String = "min" - override def prettyName: String = symbol + override def toString: String = s"MinOf($left, $right)" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 67e7dc4ec8b14..5def57b067424 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -43,7 +43,7 @@ case class CreateArray(children: Seq[Expression]) extends Expression { children.map(_.eval(input)) } - override def prettyName: String = "array" + override def toString: String = s"Array(${children.mkString(",")})" } /** @@ -71,6 +71,4 @@ case class CreateStruct(children: Seq[Expression]) extends Expression { override def eval(input: InternalRow): Any = { InternalRow(children.map(_.eval(input)): _*) } - - override def prettyName: String = "struct" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala index 145d323a9f0bb..78be2824347d7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala @@ -38,6 +38,8 @@ case class Coalesce(children: Seq[Expression]) extends Expression { } } + override def toString: String = s"Coalesce(${children.mkString(",")})" + override def dataType: DataType = children.head.dataType override def eval(input: InternalRow): Any = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 34df89a163895..a777f77add2db 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -120,7 +120,7 @@ case class InSet(value: Expression, hset: Set[Any]) } case class And(left: Expression, right: Expression) - extends BinaryOperator with Predicate with AutoCastInputTypes { + extends BinaryExpression with Predicate with AutoCastInputTypes { override def inputTypes: Seq[DataType] = Seq(BooleanType, BooleanType) @@ -169,7 +169,7 @@ case class And(left: Expression, right: Expression) } case class Or(left: Expression, right: Expression) - extends BinaryOperator with Predicate with AutoCastInputTypes { + extends BinaryExpression with Predicate with AutoCastInputTypes { override def inputTypes: Seq[DataType] = Seq(BooleanType, BooleanType) @@ -217,7 +217,7 @@ case class Or(left: Expression, right: Expression) } } -abstract class BinaryComparison extends BinaryOperator with Predicate { +abstract class BinaryComparison extends BinaryExpression with Predicate { self: Product => override def checkInputDataTypes(): TypeCheckResult = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala index 5d51a4ca65332..daa9f4403ffab 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala @@ -137,6 +137,8 @@ case class CombineSets(left: Expression, right: Expression) extends BinaryExpres override def dataType: DataType = left.dataType + override def symbol: String = "++=" + override def eval(input: InternalRow): Any = { val leftEval = left.eval(input).asInstanceOf[OpenHashSet[Any]] if(leftEval != null) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index b020f2bbc5818..4cbfc4e084948 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -75,6 +75,8 @@ trait StringRegexExpression extends AutoCastInputTypes { case class Like(left: Expression, right: Expression) extends BinaryExpression with StringRegexExpression { + override def symbol: String = "LIKE" + // replace the _ with .{1} exactly match 1 time of any character // replace the % with .*, match 0 or more times with any character override def escape(v: String): String = @@ -99,16 +101,14 @@ case class Like(left: Expression, right: Expression) } override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).matches() - - override def toString: String = s"$left LIKE $right" } case class RLike(left: Expression, right: Expression) extends BinaryExpression with StringRegexExpression { + override def symbol: String = "RLIKE" override def escape(v: String): String = v override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).find(0) - override def toString: String = s"$left RLIKE $right" } trait CaseConversionExpression extends AutoCastInputTypes { @@ -134,7 +134,9 @@ trait CaseConversionExpression extends AutoCastInputTypes { */ case class Upper(child: Expression) extends UnaryExpression with CaseConversionExpression { - override def convert(v: UTF8String): UTF8String = v.toUpperCase + override def convert(v: UTF8String): UTF8String = v.toUpperCase() + + override def toString: String = s"Upper($child)" override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, c => s"($c).toUpperCase()") @@ -146,7 +148,9 @@ case class Upper(child: Expression) extends UnaryExpression with CaseConversionE */ case class Lower(child: Expression) extends UnaryExpression with CaseConversionExpression { - override def convert(v: UTF8String): UTF8String = v.toLowerCase + override def convert(v: UTF8String): UTF8String = v.toLowerCase() + + override def toString: String = s"Lower($child)" override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, c => s"($c).toLowerCase()") @@ -174,6 +178,8 @@ trait StringComparison extends AutoCastInputTypes { } } + override def symbol: String = nodeName + override def toString: String = s"$nodeName($left, $right)" } @@ -278,6 +284,12 @@ case class Substring(str: Expression, pos: Expression, len: Expression) } } } + + override def toString: String = len match { + // TODO: This is broken because max is not an integer value. + case max if max == Integer.MAX_VALUE => s"SUBSTR($str, $pos)" + case _ => s"SUBSTR($str, $pos, $len)" + } } /** @@ -292,9 +304,9 @@ case class StringLength(child: Expression) extends UnaryExpression with AutoCast if (string == null) null else string.asInstanceOf[UTF8String].length } + override def toString: String = s"length($child)" + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, c => s"($c).length()") } - - override def prettyName: String = "length" } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index 86792f0217572..bda217935cb05 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -73,7 +73,7 @@ class TreeNodeSuite extends SparkFunSuite { val expected = Seq("+", "1", "*", "2", "-", "3", "4") val expression = Add(Literal(1), Multiply(Literal(2), Subtract(Literal(3), Literal(4)))) expression transformDown { - case b: BinaryOperator => actual.append(b.symbol); b + case b: BinaryExpression => actual.append(b.symbol); b case l: Literal => actual.append(l.toString); l } @@ -85,7 +85,7 @@ class TreeNodeSuite extends SparkFunSuite { val expected = Seq("1", "2", "3", "4", "-", "*", "+") val expression = Add(Literal(1), Multiply(Literal(2), Subtract(Literal(3), Literal(4)))) expression transformUp { - case b: BinaryOperator => actual.append(b.symbol); b + case b: BinaryExpression => actual.append(b.symbol); b case l: Literal => actual.append(l.toString); l } @@ -125,7 +125,7 @@ class TreeNodeSuite extends SparkFunSuite { val expected = Seq("1", "2", "3", "4", "-", "*", "+") val expression = Add(Literal(1), Multiply(Literal(2), Subtract(Literal(3), Literal(4)))) expression foreachUp { - case b: BinaryOperator => actual.append(b.symbol); + case b: BinaryExpression => actual.append(b.symbol); case l: Literal => actual.append(l.toString); } From 9fd13d5613b6d16a78d97d4798f085b56107d343 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 1 Jul 2015 21:14:13 -0700 Subject: [PATCH 18/18] [SPARK-8770][SQL] Create BinaryOperator abstract class. Our current BinaryExpression abstract class is not for generic binary expressions, i.e. it requires left/right children to have the same type. However, due to its name, contributors build new binary expressions that don't have that assumption (e.g. Sha) and still extend BinaryExpression. This patch creates a new BinaryOperator abstract class, and update the analyzer o only apply type casting rule there. This patch also adds the notion of "prettyName" to expressions, which defines the user-facing name for the expression. Author: Reynold Xin Closes #7174 from rxin/binary-opterator and squashes the following commits: f31900d [Reynold Xin] [SPARK-8770][SQL] Create BinaryOperator abstract class. fceb216 [Reynold Xin] Merge branch 'master' of github.com:apache/spark into binary-opterator d8518cf [Reynold Xin] Updated Python tests. --- python/pyspark/sql/dataframe.py | 10 +- python/pyspark/sql/functions.py | 4 +- python/pyspark/sql/group.py | 24 +-- .../catalyst/analysis/HiveTypeCoercion.scala | 17 +- .../expressions/ExpectsInputTypes.scala | 59 +++++++ .../sql/catalyst/expressions/Expression.scala | 161 +++++++++--------- .../sql/catalyst/expressions/ScalaUDF.scala | 2 +- .../sql/catalyst/expressions/aggregates.scala | 9 +- .../sql/catalyst/expressions/arithmetic.scala | 14 +- .../expressions/complexTypeCreator.scala | 4 +- .../catalyst/expressions/nullFunctions.scala | 2 - .../sql/catalyst/expressions/predicates.scala | 6 +- .../spark/sql/catalyst/expressions/sets.scala | 2 - .../expressions/stringOperations.scala | 26 +-- .../sql/catalyst/trees/TreeNodeSuite.scala | 6 +- 15 files changed, 191 insertions(+), 155 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 273a40dd526cf..1e9c657cf81b3 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -802,11 +802,11 @@ def groupBy(self, *cols): Each element should be a column name (string) or an expression (:class:`Column`). >>> df.groupBy().avg().collect() - [Row(AVG(age)=3.5)] + [Row(avg(age)=3.5)] >>> df.groupBy('name').agg({'age': 'mean'}).collect() - [Row(name=u'Alice', AVG(age)=2.0), Row(name=u'Bob', AVG(age)=5.0)] + [Row(name=u'Alice', avg(age)=2.0), Row(name=u'Bob', avg(age)=5.0)] >>> df.groupBy(df.name).avg().collect() - [Row(name=u'Alice', AVG(age)=2.0), Row(name=u'Bob', AVG(age)=5.0)] + [Row(name=u'Alice', avg(age)=2.0), Row(name=u'Bob', avg(age)=5.0)] >>> df.groupBy(['name', df.age]).count().collect() [Row(name=u'Bob', age=5, count=1), Row(name=u'Alice', age=2, count=1)] """ @@ -864,10 +864,10 @@ def agg(self, *exprs): (shorthand for ``df.groupBy.agg()``). >>> df.agg({"age": "max"}).collect() - [Row(MAX(age)=5)] + [Row(max(age)=5)] >>> from pyspark.sql import functions as F >>> df.agg(F.min(df.age)).collect() - [Row(MIN(age)=2)] + [Row(min(age)=2)] """ return self.groupBy().agg(*exprs) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 4e2be88e9e3b9..f9a15d4a66309 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -266,7 +266,7 @@ def coalesce(*cols): >>> cDf.select(coalesce(cDf["a"], cDf["b"])).show() +-------------+ - |Coalesce(a,b)| + |coalesce(a,b)| +-------------+ | null| | 1| @@ -275,7 +275,7 @@ def coalesce(*cols): >>> cDf.select('*', coalesce(cDf["a"], lit(0.0))).show() +----+----+---------------+ - | a| b|Coalesce(a,0.0)| + | a| b|coalesce(a,0.0)| +----+----+---------------+ |null|null| 0.0| | 1|null| 1.0| diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index 5a37a673ee80c..04594d5a836ce 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -75,11 +75,11 @@ def agg(self, *exprs): >>> gdf = df.groupBy(df.name) >>> gdf.agg({"*": "count"}).collect() - [Row(name=u'Alice', COUNT(1)=1), Row(name=u'Bob', COUNT(1)=1)] + [Row(name=u'Alice', count(1)=1), Row(name=u'Bob', count(1)=1)] >>> from pyspark.sql import functions as F >>> gdf.agg(F.min(df.age)).collect() - [Row(name=u'Alice', MIN(age)=2), Row(name=u'Bob', MIN(age)=5)] + [Row(name=u'Alice', min(age)=2), Row(name=u'Bob', min(age)=5)] """ assert exprs, "exprs should not be empty" if len(exprs) == 1 and isinstance(exprs[0], dict): @@ -110,9 +110,9 @@ def mean(self, *cols): :param cols: list of column names (string). Non-numeric columns are ignored. >>> df.groupBy().mean('age').collect() - [Row(AVG(age)=3.5)] + [Row(avg(age)=3.5)] >>> df3.groupBy().mean('age', 'height').collect() - [Row(AVG(age)=3.5, AVG(height)=82.5)] + [Row(avg(age)=3.5, avg(height)=82.5)] """ @df_varargs_api @@ -125,9 +125,9 @@ def avg(self, *cols): :param cols: list of column names (string). Non-numeric columns are ignored. >>> df.groupBy().avg('age').collect() - [Row(AVG(age)=3.5)] + [Row(avg(age)=3.5)] >>> df3.groupBy().avg('age', 'height').collect() - [Row(AVG(age)=3.5, AVG(height)=82.5)] + [Row(avg(age)=3.5, avg(height)=82.5)] """ @df_varargs_api @@ -136,9 +136,9 @@ def max(self, *cols): """Computes the max value for each numeric columns for each group. >>> df.groupBy().max('age').collect() - [Row(MAX(age)=5)] + [Row(max(age)=5)] >>> df3.groupBy().max('age', 'height').collect() - [Row(MAX(age)=5, MAX(height)=85)] + [Row(max(age)=5, max(height)=85)] """ @df_varargs_api @@ -149,9 +149,9 @@ def min(self, *cols): :param cols: list of column names (string). Non-numeric columns are ignored. >>> df.groupBy().min('age').collect() - [Row(MIN(age)=2)] + [Row(min(age)=2)] >>> df3.groupBy().min('age', 'height').collect() - [Row(MIN(age)=2, MIN(height)=80)] + [Row(min(age)=2, min(height)=80)] """ @df_varargs_api @@ -162,9 +162,9 @@ def sum(self, *cols): :param cols: list of column names (string). Non-numeric columns are ignored. >>> df.groupBy().sum('age').collect() - [Row(SUM(age)=7)] + [Row(sum(age)=7)] >>> df3.groupBy().sum('age', 'height').collect() - [Row(SUM(age)=7, SUM(height)=165)] + [Row(sum(age)=7, sum(height)=165)] """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 2ab5cb666fbcd..8420c54f7c335 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -150,6 +150,7 @@ object HiveTypeCoercion { * Converts string "NaN"s that are in binary operators with a NaN-able types (Float / Double) to * the appropriate numeric equivalent. */ + // TODO: remove this rule and make Cast handle Nan. object ConvertNaNs extends Rule[LogicalPlan] { private val StringNaN = Literal("NaN") @@ -159,19 +160,19 @@ object HiveTypeCoercion { case e if !e.childrenResolved => e /* Double Conversions */ - case b @ BinaryExpression(StringNaN, right @ DoubleType()) => + case b @ BinaryOperator(StringNaN, right @ DoubleType()) => b.makeCopy(Array(Literal(Double.NaN), right)) - case b @ BinaryExpression(left @ DoubleType(), StringNaN) => + case b @ BinaryOperator(left @ DoubleType(), StringNaN) => b.makeCopy(Array(left, Literal(Double.NaN))) /* Float Conversions */ - case b @ BinaryExpression(StringNaN, right @ FloatType()) => + case b @ BinaryOperator(StringNaN, right @ FloatType()) => b.makeCopy(Array(Literal(Float.NaN), right)) - case b @ BinaryExpression(left @ FloatType(), StringNaN) => + case b @ BinaryOperator(left @ FloatType(), StringNaN) => b.makeCopy(Array(left, Literal(Float.NaN))) /* Use float NaN by default to avoid unnecessary type widening */ - case b @ BinaryExpression(left @ StringNaN, StringNaN) => + case b @ BinaryOperator(left @ StringNaN, StringNaN) => b.makeCopy(Array(left, Literal(Float.NaN))) } } @@ -245,12 +246,12 @@ object HiveTypeCoercion { Union(newLeft, newRight) - // Also widen types for BinaryExpressions. + // Also widen types for BinaryOperator. case q: LogicalPlan => q transformExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e - case b @ BinaryExpression(left, right) if left.dataType != right.dataType => + case b @ BinaryOperator(left, right) if left.dataType != right.dataType => findTightestCommonTypeOfTwo(left.dataType, right.dataType).map { widestType => val newLeft = if (left.dataType == widestType) left else Cast(left, widestType) val newRight = if (right.dataType == widestType) right else Cast(right, widestType) @@ -478,7 +479,7 @@ object HiveTypeCoercion { // Promote integers inside a binary expression with fixed-precision decimals to decimals, // and fixed-precision decimals in an expression with floats / doubles to doubles - case b @ BinaryExpression(left, right) if left.dataType != right.dataType => + case b @ BinaryOperator(left, right) if left.dataType != right.dataType => (left.dataType, right.dataType) match { case (t, DecimalType.Fixed(p, s)) if intTypeToFixed.contains(t) => b.makeCopy(Array(Cast(left, intTypeToFixed(t)), right)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala new file mode 100644 index 0000000000000..450fc4165f93b --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.types.DataType + + +/** + * An trait that gets mixin to define the expected input types of an expression. + */ +trait ExpectsInputTypes { self: Expression => + + /** + * Expected input types from child expressions. The i-th position in the returned seq indicates + * the type requirement for the i-th child. + * + * The possible values at each position are: + * 1. a specific data type, e.g. LongType, StringType. + * 2. a non-leaf data type, e.g. NumericType, IntegralType, FractionalType. + * 3. a list of specific data types, e.g. Seq(StringType, BinaryType). + */ + def inputTypes: Seq[Any] + + override def checkInputDataTypes(): TypeCheckResult = { + // We will do the type checking in `HiveTypeCoercion`, so always returning success here. + TypeCheckResult.TypeCheckSuccess + } +} + +/** + * Expressions that require a specific `DataType` as input should implement this trait + * so that the proper type conversions can be performed in the analyzer. + */ +trait AutoCastInputTypes { self: Expression => + + def inputTypes: Seq[DataType] + + override def checkInputDataTypes(): TypeCheckResult = { + // We will always do type casting for `AutoCastInputTypes` in `HiveTypeCoercion`, + // so type mismatch error won't be reported here, but for underling `Cast`s. + TypeCheckResult.TypeCheckSuccess + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index e18a3118945e8..cafbbafdca207 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -119,17 +119,6 @@ abstract class Expression extends TreeNode[Expression] { */ def childrenResolved: Boolean = children.forall(_.resolved) - /** - * Returns a string representation of this expression that does not have developer centric - * debugging information like the expression id. - */ - def prettyString: String = { - transform { - case a: AttributeReference => PrettyAttribute(a.name) - case u: UnresolvedAttribute => PrettyAttribute(u.name) - }.toString - } - /** * Returns true when two expressions will always compute the same result, even if they differ * cosmetically (i.e. capitalization of names in attributes may be different). @@ -154,71 +143,40 @@ abstract class Expression extends TreeNode[Expression] { * Note: it's not valid to call this method until `childrenResolved == true`. */ def checkInputDataTypes(): TypeCheckResult = TypeCheckResult.TypeCheckSuccess -} - -abstract class BinaryExpression extends Expression with trees.BinaryNode[Expression] { - self: Product => - - def symbol: String = sys.error(s"BinaryExpressions must override either toString or symbol") - - override def foldable: Boolean = left.foldable && right.foldable - - override def nullable: Boolean = left.nullable || right.nullable - - override def toString: String = s"($left $symbol $right)" /** - * Short hand for generating binary evaluation code. - * If either of the sub-expressions is null, the result of this computation - * is assumed to be null. - * - * @param f accepts two variable names and returns Java code to compute the output. + * Returns a user-facing string representation of this expression's name. + * This should usually match the name of the function in SQL. */ - protected def defineCodeGen( - ctx: CodeGenContext, - ev: GeneratedExpressionCode, - f: (String, String) => String): String = { - nullSafeCodeGen(ctx, ev, (result, eval1, eval2) => { - s"$result = ${f(eval1, eval2)};" - }) - } + def prettyName: String = getClass.getSimpleName.toLowerCase /** - * Short hand for generating binary evaluation code. - * If either of the sub-expressions is null, the result of this computation - * is assumed to be null. + * Returns a user-facing string representation of this expression, i.e. does not have developer + * centric debugging information like the expression id. */ - protected def nullSafeCodeGen( - ctx: CodeGenContext, - ev: GeneratedExpressionCode, - f: (String, String, String) => String): String = { - val eval1 = left.gen(ctx) - val eval2 = right.gen(ctx) - val resultCode = f(ev.primitive, eval1.primitive, eval2.primitive) - s""" - ${eval1.code} - boolean ${ev.isNull} = ${eval1.isNull}; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; - if (!${ev.isNull}) { - ${eval2.code} - if (!${eval2.isNull}) { - $resultCode - } else { - ${ev.isNull} = true; - } - } - """ + def prettyString: String = { + transform { + case a: AttributeReference => PrettyAttribute(a.name) + case u: UnresolvedAttribute => PrettyAttribute(u.name) + }.toString } -} -private[sql] object BinaryExpression { - def unapply(e: BinaryExpression): Option[(Expression, Expression)] = Some((e.left, e.right)) + override def toString: String = prettyName + children.mkString("(", ",", ")") } + +/** + * A leaf expression, i.e. one without any child expressions. + */ abstract class LeafExpression extends Expression with trees.LeafNode[Expression] { self: Product => } + +/** + * An expression with one input and one output. The output is by default evaluated to null + * if the input is evaluated to null. + */ abstract class UnaryExpression extends Expression with trees.UnaryNode[Expression] { self: Product => @@ -265,39 +223,76 @@ abstract class UnaryExpression extends Expression with trees.UnaryNode[Expressio } } + /** - * An trait that gets mixin to define the expected input types of an expression. + * An expression with two inputs and one output. The output is by default evaluated to null + * if any input is evaluated to null. */ -trait ExpectsInputTypes { self: Expression => +abstract class BinaryExpression extends Expression with trees.BinaryNode[Expression] { + self: Product => + + override def foldable: Boolean = left.foldable && right.foldable + + override def nullable: Boolean = left.nullable || right.nullable /** - * Expected input types from child expressions. The i-th position in the returned seq indicates - * the type requirement for the i-th child. + * Short hand for generating binary evaluation code. + * If either of the sub-expressions is null, the result of this computation + * is assumed to be null. * - * The possible values at each position are: - * 1. a specific data type, e.g. LongType, StringType. - * 2. a non-leaf data type, e.g. NumericType, IntegralType, FractionalType. - * 3. a list of specific data types, e.g. Seq(StringType, BinaryType). + * @param f accepts two variable names and returns Java code to compute the output. */ - def inputTypes: Seq[Any] + protected def defineCodeGen( + ctx: CodeGenContext, + ev: GeneratedExpressionCode, + f: (String, String) => String): String = { + nullSafeCodeGen(ctx, ev, (result, eval1, eval2) => { + s"$result = ${f(eval1, eval2)};" + }) + } - override def checkInputDataTypes(): TypeCheckResult = { - // We will do the type checking in `HiveTypeCoercion`, so always returning success here. - TypeCheckResult.TypeCheckSuccess + /** + * Short hand for generating binary evaluation code. + * If either of the sub-expressions is null, the result of this computation + * is assumed to be null. + */ + protected def nullSafeCodeGen( + ctx: CodeGenContext, + ev: GeneratedExpressionCode, + f: (String, String, String) => String): String = { + val eval1 = left.gen(ctx) + val eval2 = right.gen(ctx) + val resultCode = f(ev.primitive, eval1.primitive, eval2.primitive) + s""" + ${eval1.code} + boolean ${ev.isNull} = ${eval1.isNull}; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${eval2.code} + if (!${eval2.isNull}) { + $resultCode + } else { + ${ev.isNull} = true; + } + } + """ } } + /** - * Expressions that require a specific `DataType` as input should implement this trait - * so that the proper type conversions can be performed in the analyzer. + * An expression that has two inputs that are expected to the be same type. If the two inputs have + * different types, the analyzer will find the tightest common type and do the proper type casting. */ -trait AutoCastInputTypes { self: Expression => +abstract class BinaryOperator extends BinaryExpression { + self: Product => - def inputTypes: Seq[DataType] + def symbol: String - override def checkInputDataTypes(): TypeCheckResult = { - // We will always do type casting for `AutoCastInputTypes` in `HiveTypeCoercion`, - // so type mismatch error won't be reported here, but for underling `Cast`s. - TypeCheckResult.TypeCheckSuccess - } + override def toString: String = s"($left $symbol $right)" +} + + +private[sql] object BinaryOperator { + def unapply(e: BinaryOperator): Option[(Expression, Expression)] = Some((e.left, e.right)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index ebabb6f117851..caf021b016a41 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -29,7 +29,7 @@ case class ScalaUDF(function: AnyRef, dataType: DataType, children: Seq[Expressi override def nullable: Boolean = true - override def toString: String = s"scalaUDF(${children.mkString(",")})" + override def toString: String = s"UDF(${children.mkString(",")})" // scalastyle:off diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index a9fc54c548f49..64e07bd2a17db 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -94,7 +94,6 @@ case class Min(child: Expression) extends PartialAggregate with trees.UnaryNode[ override def nullable: Boolean = true override def dataType: DataType = child.dataType - override def toString: String = s"MIN($child)" override def asPartial: SplitEvaluation = { val partialMin = Alias(Min(child), "PartialMin")() @@ -128,7 +127,6 @@ case class Max(child: Expression) extends PartialAggregate with trees.UnaryNode[ override def nullable: Boolean = true override def dataType: DataType = child.dataType - override def toString: String = s"MAX($child)" override def asPartial: SplitEvaluation = { val partialMax = Alias(Max(child), "PartialMax")() @@ -162,7 +160,6 @@ case class Count(child: Expression) extends PartialAggregate with trees.UnaryNod override def nullable: Boolean = false override def dataType: LongType.type = LongType - override def toString: String = s"COUNT($child)" override def asPartial: SplitEvaluation = { val partialCount = Alias(Count(child), "PartialCount")() @@ -390,6 +387,8 @@ case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05) case class Average(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { + override def prettyName: String = "avg" + override def nullable: Boolean = true override def dataType: DataType = child.dataType match { @@ -401,8 +400,6 @@ case class Average(child: Expression) extends PartialAggregate with trees.UnaryN DoubleType } - override def toString: String = s"AVG($child)" - override def asPartial: SplitEvaluation = { child.dataType match { case DecimalType.Fixed(_, _) | DecimalType.Unlimited => @@ -494,8 +491,6 @@ case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[ child.dataType } - override def toString: String = s"SUM($child)" - override def asPartial: SplitEvaluation = { child.dataType match { case DecimalType.Fixed(_, _) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 5363b3556886a..4fbf4c87009c2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -57,7 +57,7 @@ case class UnaryMinus(child: Expression) extends UnaryArithmetic { } case class UnaryPositive(child: Expression) extends UnaryArithmetic { - override def toString: String = s"positive($child)" + override def prettyName: String = "positive" override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = defineCodeGen(ctx, ev, c => c) @@ -69,8 +69,6 @@ case class UnaryPositive(child: Expression) extends UnaryArithmetic { * A function that get the absolute value of the numeric value. */ case class Abs(child: Expression) extends UnaryArithmetic { - override def toString: String = s"Abs($child)" - override def checkInputDataTypes(): TypeCheckResult = TypeUtils.checkForNumericExpr(child.dataType, "function abs") @@ -79,10 +77,9 @@ case class Abs(child: Expression) extends UnaryArithmetic { protected override def evalInternal(evalE: Any) = numeric.abs(evalE) } -abstract class BinaryArithmetic extends BinaryExpression { +abstract class BinaryArithmetic extends BinaryOperator { self: Product => - override def dataType: DataType = left.dataType override def checkInputDataTypes(): TypeCheckResult = { @@ -360,7 +357,9 @@ case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic { } """ } - override def toString: String = s"MaxOf($left, $right)" + + override def symbol: String = "max" + override def prettyName: String = symbol } case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic { @@ -413,5 +412,6 @@ case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic { """ } - override def toString: String = s"MinOf($left, $right)" + override def symbol: String = "min" + override def prettyName: String = symbol } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 5def57b067424..67e7dc4ec8b14 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -43,7 +43,7 @@ case class CreateArray(children: Seq[Expression]) extends Expression { children.map(_.eval(input)) } - override def toString: String = s"Array(${children.mkString(",")})" + override def prettyName: String = "array" } /** @@ -71,4 +71,6 @@ case class CreateStruct(children: Seq[Expression]) extends Expression { override def eval(input: InternalRow): Any = { InternalRow(children.map(_.eval(input)): _*) } + + override def prettyName: String = "struct" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala index 78be2824347d7..145d323a9f0bb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala @@ -38,8 +38,6 @@ case class Coalesce(children: Seq[Expression]) extends Expression { } } - override def toString: String = s"Coalesce(${children.mkString(",")})" - override def dataType: DataType = children.head.dataType override def eval(input: InternalRow): Any = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index a777f77add2db..34df89a163895 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -120,7 +120,7 @@ case class InSet(value: Expression, hset: Set[Any]) } case class And(left: Expression, right: Expression) - extends BinaryExpression with Predicate with AutoCastInputTypes { + extends BinaryOperator with Predicate with AutoCastInputTypes { override def inputTypes: Seq[DataType] = Seq(BooleanType, BooleanType) @@ -169,7 +169,7 @@ case class And(left: Expression, right: Expression) } case class Or(left: Expression, right: Expression) - extends BinaryExpression with Predicate with AutoCastInputTypes { + extends BinaryOperator with Predicate with AutoCastInputTypes { override def inputTypes: Seq[DataType] = Seq(BooleanType, BooleanType) @@ -217,7 +217,7 @@ case class Or(left: Expression, right: Expression) } } -abstract class BinaryComparison extends BinaryExpression with Predicate { +abstract class BinaryComparison extends BinaryOperator with Predicate { self: Product => override def checkInputDataTypes(): TypeCheckResult = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala index daa9f4403ffab..5d51a4ca65332 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala @@ -137,8 +137,6 @@ case class CombineSets(left: Expression, right: Expression) extends BinaryExpres override def dataType: DataType = left.dataType - override def symbol: String = "++=" - override def eval(input: InternalRow): Any = { val leftEval = left.eval(input).asInstanceOf[OpenHashSet[Any]] if(leftEval != null) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 4cbfc4e084948..b020f2bbc5818 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -75,8 +75,6 @@ trait StringRegexExpression extends AutoCastInputTypes { case class Like(left: Expression, right: Expression) extends BinaryExpression with StringRegexExpression { - override def symbol: String = "LIKE" - // replace the _ with .{1} exactly match 1 time of any character // replace the % with .*, match 0 or more times with any character override def escape(v: String): String = @@ -101,14 +99,16 @@ case class Like(left: Expression, right: Expression) } override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).matches() + + override def toString: String = s"$left LIKE $right" } case class RLike(left: Expression, right: Expression) extends BinaryExpression with StringRegexExpression { - override def symbol: String = "RLIKE" override def escape(v: String): String = v override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).find(0) + override def toString: String = s"$left RLIKE $right" } trait CaseConversionExpression extends AutoCastInputTypes { @@ -134,9 +134,7 @@ trait CaseConversionExpression extends AutoCastInputTypes { */ case class Upper(child: Expression) extends UnaryExpression with CaseConversionExpression { - override def convert(v: UTF8String): UTF8String = v.toUpperCase() - - override def toString: String = s"Upper($child)" + override def convert(v: UTF8String): UTF8String = v.toUpperCase override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, c => s"($c).toUpperCase()") @@ -148,9 +146,7 @@ case class Upper(child: Expression) extends UnaryExpression with CaseConversionE */ case class Lower(child: Expression) extends UnaryExpression with CaseConversionExpression { - override def convert(v: UTF8String): UTF8String = v.toLowerCase() - - override def toString: String = s"Lower($child)" + override def convert(v: UTF8String): UTF8String = v.toLowerCase override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, c => s"($c).toLowerCase()") @@ -178,8 +174,6 @@ trait StringComparison extends AutoCastInputTypes { } } - override def symbol: String = nodeName - override def toString: String = s"$nodeName($left, $right)" } @@ -284,12 +278,6 @@ case class Substring(str: Expression, pos: Expression, len: Expression) } } } - - override def toString: String = len match { - // TODO: This is broken because max is not an integer value. - case max if max == Integer.MAX_VALUE => s"SUBSTR($str, $pos)" - case _ => s"SUBSTR($str, $pos, $len)" - } } /** @@ -304,9 +292,9 @@ case class StringLength(child: Expression) extends UnaryExpression with AutoCast if (string == null) null else string.asInstanceOf[UTF8String].length } - override def toString: String = s"length($child)" - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, c => s"($c).length()") } + + override def prettyName: String = "length" } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index bda217935cb05..86792f0217572 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -73,7 +73,7 @@ class TreeNodeSuite extends SparkFunSuite { val expected = Seq("+", "1", "*", "2", "-", "3", "4") val expression = Add(Literal(1), Multiply(Literal(2), Subtract(Literal(3), Literal(4)))) expression transformDown { - case b: BinaryExpression => actual.append(b.symbol); b + case b: BinaryOperator => actual.append(b.symbol); b case l: Literal => actual.append(l.toString); l } @@ -85,7 +85,7 @@ class TreeNodeSuite extends SparkFunSuite { val expected = Seq("1", "2", "3", "4", "-", "*", "+") val expression = Add(Literal(1), Multiply(Literal(2), Subtract(Literal(3), Literal(4)))) expression transformUp { - case b: BinaryExpression => actual.append(b.symbol); b + case b: BinaryOperator => actual.append(b.symbol); b case l: Literal => actual.append(l.toString); l } @@ -125,7 +125,7 @@ class TreeNodeSuite extends SparkFunSuite { val expected = Seq("1", "2", "3", "4", "-", "*", "+") val expression = Add(Literal(1), Multiply(Literal(2), Subtract(Literal(3), Literal(4)))) expression foreachUp { - case b: BinaryExpression => actual.append(b.symbol); + case b: BinaryOperator => actual.append(b.symbol); case l: Literal => actual.append(l.toString); }